From 46a58e0849f1743631ca53af46291526a9b22304 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Mon, 18 Jul 2022 12:41:43 -0700 Subject: [PATCH 001/410] [oneDNN] Add 3D suuport to layout optimizer --- .../generic_layout_optimizer_test.cc | 86 ++++++++++++++ .../generic_layout_optimizer_transposer.cc | 108 +++++++++++++++--- .../generic_layout_optimizer_transposer.h | 10 ++ ...ric_layout_optimizer_transposer_factory.cc | 3 + 4 files changed, 194 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc index 255e3ec99995c8..fd5296389f52f7 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc @@ -41,6 +41,7 @@ using ::tensorflow::Scope; using ::tensorflow::ops::Conv2D; using ::tensorflow::ops::Identity; using ::tensorflow::ops::RandomUniform; +using ::tensorflow::ops::Conv3D; constexpr int kBatchSize = 32; constexpr int kWidth = 10; @@ -79,6 +80,10 @@ constexpr int kDepthOut = 16; { 0, 2, 3, 1 } #define PERMUTATION_DST_TO_SRC \ { 0, 3, 1, 2 } +#define DIMS_5D(n, d, h, w, c) \ + { n, c, d, h, w } +#define SRC_DATA_FORMAT_5D "NCDHW" +#define DST_DATA_FORMAT_5D "NDHWC" #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) template @@ -159,6 +164,37 @@ Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size, return conv_backprop_input; } +template +Output SimpleConv3D(tensorflow::Scope* s, int input_size, int filter_size, + const string& padding, const string& device) { + int batch_size = 8; + int input_height = input_size; + int input_width = input_size; + int input_depth = 4; + int input_channel = 3; + int filter_count = 6; + int stride = 1; + TensorShape input_shape(DIMS_5D(batch_size, input_depth, input_height, + input_width, input_channel)); + Tensor input_data(DataTypeToEnum::value, input_shape); + test::FillIota(&input_data, static_cast(1)); + Output input = + ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); + + TensorShape filter_shape( + {filter_size, filter_size, filter_size, input_channel, filter_count}); + Tensor filter_data(DataTypeToEnum::value, filter_shape); + test::FillIota(&filter_data, static_cast(1)); + Output filter = + ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); + + Output conv = + ops::Conv3D(s->WithOpName("Conv3D").WithDevice(device), input, filter, + DIMS_5D(1, stride, stride, stride, 1), padding, + ops::Conv3D::Attrs().DataFormat(SRC_DATA_FORMAT_5D)); + return conv; +} + class GenericLayoutOptimizerTest : public GrapplerTest { protected: void SetUp() override { @@ -674,6 +710,56 @@ TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) { output_shapes.DebugString()); } +TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { +#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) + GTEST_SKIP() << "CUDA or ROCm is enabled"; +#endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) + // A simple graph contains 1 Conv3D node, 2 input and 1 output nodes. + // Data format is NCDHW on CPU. 
+ Scope scope = Scope::NewRootScope(); + + auto conv3d = SimpleConv3D(&scope, 32, 1, "VALID", "/CPU:0"); + auto identity = Identity(scope.WithOpName("Output"), conv3d); + GrapplerItem item; + TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); + + GenericLayoutOptimizer optimizer(REWRITER_CONFIG); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); + + Status status; + utils::GraphView graph_view(&output, &status); + TF_ASSERT_OK(status); + // The expected optimized graph contains 2 extra sets of Transpose nodes and + // has the Conv3D's data_format set "NDHWC" on CPU. + auto* input_transpose_node = graph_view.GetNode( + absl::StrCat("Conv3D-0-Transpose", SRC_DATA_FORMAT_5D, "To", + DST_DATA_FORMAT_5D, "-LayoutOptimizer")); + + ASSERT_NE(input_transpose_node, nullptr); + ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2); + VerifyRegularFaninMatch(input_transpose_node, 0, "Input", 0); + + auto* conv3d_node = graph_view.GetNode("Conv3D"); + ASSERT_NE(conv3d_node, nullptr); + ASSERT_EQ(conv3d_node->NumRegularFanins(), 2); + VerifyRegularFaninMatch(conv3d_node, 0, input_transpose_node->GetName(), 0); + VerifyRegularFaninMatch(conv3d_node, 1, "Filter", 0); + VerifyDataFormatAttributeMatch(conv3d_node, DST_DATA_FORMAT_5D); + + auto* output_transpose_node = graph_view.GetNode( + absl::StrCat("Conv3D-0-0-Transpose", DST_DATA_FORMAT_5D, "To", + SRC_DATA_FORMAT_5D, "-LayoutOptimizer")); + ASSERT_NE(output_transpose_node, nullptr); + ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2); + VerifyRegularFaninMatch(output_transpose_node, 0, conv3d_node->GetName(), 0); + + auto* output_node = graph_view.GetNode("Output"); + ASSERT_NE(output_node, nullptr); + ASSERT_EQ(output_node->NumRegularFanins(), 1); + VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0); +} + // TODO(yanzha): Add more complex Graph for test. } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index 92feae4484ab18..c745e07380c705 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -97,6 +97,16 @@ bool IsNonFloatingConv2D(const utils::MutableNodeView& node) { return false; } +bool IsNonFloatingConv3D(const utils::MutableNodeView& node) { + if (IsConv3D(*node.node())) { + const auto* attr = node.GetAttr(kAttrT); + if (attr != nullptr) { + return !kDataTypeIsFloating.Contains(attr->type()); + } + } + return false; +} + // Utils for layout agnostic transposer. bool IsComparisonOp(const NodeDef& node) { @@ -276,8 +286,10 @@ bool Transposer::ShouldProcess(const TransposeContext& context, // Only transposes floating point nodes. 
const bool is_integer_conv2d = IsNonFloatingConv2D(node); + const bool is_integer_conv3d = IsNonFloatingConv3D(node); return is_on_target_device && data_format_match && !is_integer_conv2d && + !is_integer_conv3d && !context.nodes_to_preserve.contains(node_def->name()) && !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0); } @@ -1005,6 +1017,28 @@ Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } +Status MaxPool3DTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { + DCHECK(IsMaxPool3D(*node->node())); + // We check data_input's shape instead, because the shape inference of + // MaxPool3D is not able to infer the shape when ksize or strides is not + // constant. + const auto& data_fanin = node->GetRegularFanin(0); + auto* data_fanin_node = data_fanin.node_view(); + if (!ShouldProcess(*context, *node) || + !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 5)) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, 5); + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; + TF_RETURN_IF_ERROR(UpdateNode(context, node)); + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); + TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); + return context->graph_view->GetMutationBuilder()->Apply(); +} + Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node())); @@ -1416,7 +1450,8 @@ bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform( const TransposeContext& context, const utils::MutableNodeView& node) const { for (const auto& regular_fanin : node.GetRegularFanins()) { auto* regular_fanin_node = regular_fanin.node_view(); - if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) && + if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) || + IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 5)) && ((IsAfterDstToSrcTransform(context, *regular_fanin_node) && IsLayoutAgnosticOp(*regular_fanin_node->node())) || IsLayoutOptimizerAddedDstToSrcTranspose(context, @@ -1431,7 +1466,12 @@ bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform( Status MergeTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMerge(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsEveryFaninAfterDstToSrcTransform(*context, *node)) { return OkStatus(); } @@ -1678,7 +1718,16 @@ Status SplitTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSplit(*node->node())); const auto ports = GetDataFanoutPorts(*node); - if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) || + int rank = 4; + if (!IsFanoutPortsRankN(*node, ports, 4)) { + if (!IsFanoutPortsRankN(*node, ports, 5)) { + return Status::OK(); + } else { + rank = 5; + } + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || 
!IsAfterDstToSrcTransform(*context, *node)) { return OkStatus(); } @@ -1694,7 +1743,16 @@ Status SplitVTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSplitV(*node->node())); const auto ports = GetDataFanoutPorts(*node); - if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) || + int rank = 4; + if (!IsFanoutPortsRankN(*node, ports, 4)) { + if (!IsFanoutPortsRankN(*node, ports, 5)) { + return Status::OK(); + } else { + rank = 5; + } + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return OkStatus(); } @@ -1867,17 +1925,27 @@ Status StridedSliceTransposer::PermuteMask(TransposeContext* context, Status StridedSliceTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsStridedSlice(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || - !IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) || - !HasOnlyBeginEndMask(*node) || - !IsAfterDstToSrcTransform(*context, *node)) { + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !HasOnlyBeginEndMask(*node) || + !IsAfterDstToSrcTransform(*context, *node) || + (!IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) && + !IsFaninPortsDimsNIfConst(*node, {1, 2, 3, 4}, {5}))) { return OkStatus(); } TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask")); TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask")); - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node, - kOpDataFormatVecPermute)); + if (rank == 4) { + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node, + kOpDataFormatVecPermute)); + } else if (rank == 5) { + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3, 4}, node, + kOpDataFormatVecPermute)); + } TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1885,7 +1953,12 @@ Status StridedSliceTransposer::TransposeNode(TransposeContext* context, Status SwitchTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsSwitch(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) || + const int rank = GetFaninPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return OkStatus(); } @@ -1898,7 +1971,12 @@ Status SwitchTransposer::TransposeNode(TransposeContext* context, Status TernaryOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsTernaryOp(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const int rank = GetFanoutPortRank(*node, 0); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return OkStatus(); } @@ -1968,7 +2046,7 @@ bool IsLayoutSensitiveOp(const NodeDef& node) { IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) || 
IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node) || IsConv3D(node) || IsConv3DBackpropInputV2(node) || - IsConv3DBackpropFilterV2(node); + IsConv3DBackpropFilterV2(node) || IsMaxPool3D(node); } bool IsDefaultLayoutAgnosticOp(const NodeDef& node) { @@ -2067,6 +2145,10 @@ bool IsUnaryGrad(const NodeDef& node) { bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; } +bool IsMaxPool3D(const NodeDef& node) { return node.op() == "MaxPool3D"; } + +// TODO(intel-tf): Add support for MaxPoolGrad3D + bool IsMaxPoolGradV2(const NodeDef& node) { return node.op() == "MaxPoolGradV2"; } diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index edf319c1e02cd1..d8cf90c53983f2 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -308,6 +308,14 @@ class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer { utils::MutableNodeView* node) override; }; +class MaxPool3DTransposer : public LayoutSensitiveOpTransposer { + public: + explicit MaxPool3DTransposer() : LayoutSensitiveOpTransposer() {} + + Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer { public: explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {} @@ -627,6 +635,8 @@ bool IsUnaryGrad(const NodeDef& node); bool IsMaxPoolV2(const NodeDef& node); +bool IsMaxPool3D(const NodeDef& node); + bool IsMaxPoolGradV2(const NodeDef& node); bool IsMaxPoolGradGradV1(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc index b2b608f3cb1d62..223c0d28447ddf 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.cc @@ -74,6 +74,9 @@ std::shared_ptr TransposerFactory::GetTransposer( if (IsMaxPoolGradV2(node) || IsMaxPoolGradGradV2(node)) { return GetOrCreateIfNotFound("MaxPoolGradV2"); } + if (IsMaxPool3D(node)) { + return GetOrCreateIfNotFound("MaxPool3D"); + } // Check layout agnostic ops. 
if (IsDefaultLayoutAgnosticOp(node)) { return GetOrCreateIfNotFound( From 68f0b741bd017105401fb9be73f8be6b7112d766 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Wed, 20 Jul 2022 14:55:07 -0700 Subject: [PATCH 002/410] Fix build issue --- .../generic_layout_optimizer_test.cc | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc index fd5296389f52f7..133dc528fcc3e0 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc @@ -69,6 +69,10 @@ constexpr int kDepthOut = 16; { 0, 3, 1, 2 } #define PERMUTATION_DST_TO_SRC \ { 0, 2, 3, 1 } +#define DIMS_5D(n, d, h, w, c) \ + { n, d, h, w, c } +#define SRC_DATA_FORMAT_5D "NDHWC" +#define DST_DATA_FORMAT_5D "NCDHW" #else #define DIMS(n, h, w, c) \ { n, c, h, w } @@ -711,9 +715,6 @@ TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) { } TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { -#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) - GTEST_SKIP() << "CUDA or ROCm is enabled"; -#endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) // A simple graph contains 1 Conv3D node, 2 input and 1 output nodes. // Data format is NCDHW on CPU. Scope scope = Scope::NewRootScope(); @@ -730,6 +731,19 @@ TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); + + auto* conv3d_node = graph_view.GetNode("Conv3D"); + ASSERT_NE(conv3d_node, nullptr); + ASSERT_EQ(conv3d_node->NumRegularFanins(), 2); + VerifyRegularFaninMatch(conv3d_node, 1, "Filter", 0); + + auto* output_node = graph_view.GetNode("Output"); + ASSERT_NE(output_node, nullptr); + ASSERT_EQ(output_node->NumRegularFanins(), 1); + +#if (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) + VerifyDataFormatAttributeMatch(conv3d_node, SRC_DATA_FORMAT_5D); +#else // The expected optimized graph contains 2 extra sets of Transpose nodes and // has the Conv3D's data_format set "NDHWC" on CPU. auto* input_transpose_node = graph_view.GetNode( @@ -740,11 +754,7 @@ TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { ASSERT_EQ(input_transpose_node->NumRegularFanins(), 2); VerifyRegularFaninMatch(input_transpose_node, 0, "Input", 0); - auto* conv3d_node = graph_view.GetNode("Conv3D"); - ASSERT_NE(conv3d_node, nullptr); - ASSERT_EQ(conv3d_node->NumRegularFanins(), 2); VerifyRegularFaninMatch(conv3d_node, 0, input_transpose_node->GetName(), 0); - VerifyRegularFaninMatch(conv3d_node, 1, "Filter", 0); VerifyDataFormatAttributeMatch(conv3d_node, DST_DATA_FORMAT_5D); auto* output_transpose_node = graph_view.GetNode( @@ -754,10 +764,8 @@ TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { ASSERT_EQ(output_transpose_node->NumRegularFanins(), 2); VerifyRegularFaninMatch(output_transpose_node, 0, conv3d_node->GetName(), 0); - auto* output_node = graph_view.GetNode("Output"); - ASSERT_NE(output_node, nullptr); - ASSERT_EQ(output_node->NumRegularFanins(), 1); VerifyRegularFaninMatch(output_node, 0, output_transpose_node->GetName(), 0); +#endif } // TODO(yanzha): Add more complex Graph for test. 
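For context on the two patches above: on CPU the generic layout optimizer now inserts Transpose nodes that permute Conv3D/MaxPool3D operands from NCDHW to NDHWC and permute the results back (the opposite direction on GPU). The standalone sketch below only illustrates those 5-D permutations; it is not part of the patches, Permute is a hypothetical helper rather than a TensorFlow API, and the shape is the one built by SimpleConv3D in the test (batch 8, channels 3, depth 4, height/width 32).

#include <array>
#include <cstdio>

// NCDHW -> NDHWC reads dimensions from positions {0, 2, 3, 4, 1};
// the inverse permutation, NDHWC -> NCDHW, is {0, 4, 1, 2, 3}.
constexpr std::array<int, 5> kNCDHWToNDHWC = {0, 2, 3, 4, 1};
constexpr std::array<int, 5> kNDHWCToNCDHW = {0, 4, 1, 2, 3};

// Hypothetical helper: applies a permutation to a 5-D shape.
std::array<int, 5> Permute(const std::array<int, 5>& dims,
                           const std::array<int, 5>& perm) {
  std::array<int, 5> out{};
  for (int i = 0; i < 5; ++i) out[i] = dims[perm[i]];
  return out;
}

int main() {
  // SimpleConv3D input in NCDHW order: {N=8, C=3, D=4, H=32, W=32}.
  std::array<int, 5> ncdhw = {8, 3, 4, 32, 32};
  std::array<int, 5> ndhwc = Permute(ncdhw, kNCDHWToNDHWC);  // {8, 4, 32, 32, 3}
  std::array<int, 5> round_trip = Permute(ndhwc, kNDHWCToNCDHW);
  std::printf("NDHWC: %d %d %d %d %d\n", ndhwc[0], ndhwc[1], ndhwc[2], ndhwc[3],
              ndhwc[4]);
  std::printf("NCDHW: %d %d %d %d %d\n", round_trip[0], round_trip[1],
              round_trip[2], round_trip[3], round_trip[4]);
  return 0;
}

The transpose node names checked in the CPU test ("...TransposeNCDHWToNDHWC-LayoutOptimizer") correspond to the first permutation; the output transpose applies its inverse.
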
From e06346c7b12be5f28533f5e3199985f8500bad80 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 <115997457+gaikwadrahul8@users.noreply.github.com> Date: Mon, 13 Feb 2023 03:26:24 +0530 Subject: [PATCH 003/410] Updated Install libffi7 package step While following the TensorFlow Lite C++ minimal example instructions I encountered issue with libffi7 package because since Ubuntu 20.10 comes with libff8 instead of libffi7. Please check this reference(https://askubuntu.com/questions/1286772/libffi-so-7-cannot-open-shared-object-file-no-such-file-or-directory) from askubuntu and I installed libffi7 by manually downloading the deb package from ubuntu focal (20.04) By following below steps so It's better to add below steps for Ubuntu users who are using 20.10 or later version. 1. wget http://es.archive.ubuntu.com/ubuntu/pool/main/libf/libffi/libffi7_3.3-4_amd64.deb 2. sudo dpkg -i libffi7_3.3-4_amd64.deb --- tensorflow/lite/examples/minimal/README.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/examples/minimal/README.md b/tensorflow/lite/examples/minimal/README.md index 76a44d463f4cd1..2535e0e2cd3ca5 100644 --- a/tensorflow/lite/examples/minimal/README.md +++ b/tensorflow/lite/examples/minimal/README.md @@ -14,13 +14,23 @@ sudo apt-get install cmake Or you can follow [the official cmake installation guide](https://cmake.org/install/) -#### Step 2. Clone TensorFlow repository +#### Step 2. Install libffi7 package(Optional) + +It requires libffi7. On Ubuntu 20.10 or later, you can simply run the following +command. + +```sh +wget http://es.archive.ubuntu.com/ubuntu/pool/main/libf/libffi/libffi7_3.3-4_amd64.deb +sudo dpkg -i libffi7_3.3-4_amd64.deb +``` + +#### Step 3. Clone TensorFlow repository ```sh git clone https://github.com/tensorflow/tensorflow.git tensorflow_src ``` -#### Step 3. Create CMake build directory and run CMake tool +#### Step 4. Create CMake build directory and run CMake tool ```sh mkdir minimal_build @@ -28,7 +38,7 @@ cd minimal_build cmake ../tensorflow_src/tensorflow/lite/examples/minimal ``` -#### Step 4. Build TensorFlow Lite +#### Step 5. 
Build TensorFlow Lite In the minimal_build directory, From 07ad66b922313704be2f1d7c958524add9999d0d Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Tue, 25 Apr 2023 17:49:07 -0700 Subject: [PATCH 004/410] enable bf16 kernels for FusedBatchNormV3 --- tensorflow/core/kernels/fused_batch_norm_op.cc | 12 ++++++++++++ tensorflow/core/kernels/fused_batch_norm_op_test.cc | 5 +++++ tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc | 10 ---------- tensorflow/python/ops/nn_fused_batchnorm_test.py | 2 -- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 9050ed42c17a14..609c9f176986ac 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -1962,6 +1962,18 @@ REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3") .TypeConstraint("U"), FusedBatchNormGradOpV3); + +REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .TypeConstraint("U"), + FusedBatchNormOpV3); + +REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .TypeConstraint("U"), + FusedBatchNormGradOpV3); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/kernels/fused_batch_norm_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_op_test.cc index 4e4d91d5cc61a9..52e08068427b72 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op_test.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op_test.cc @@ -219,6 +219,7 @@ TEST_F(FusedBatchNormGradOpTest, Simple) { using fp32 = float; using fp16 = Eigen::half; +using bf16 = bfloat16; template static Graph* FusedBatchNormInference(int n, int h, int w, int c, @@ -324,9 +325,11 @@ static Graph* FusedBatchNormGrad(int n, int h, int w, int c, bool is_training, BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NHWC, cpu); BM_FusedBatchNorm(64, 14, 14, 256, fp16, false, NHWC, cpu); +BM_FusedBatchNorm(64, 14, 14, 256, bf16, false, NHWC, cpu); BM_FusedBatchNorm(64, 14, 14, 256, fp32, true, NHWC, cpu); BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NHWC, cpu); +BM_FusedBatchNorm(64, 14, 14, 256, bf16, true, NHWC, cpu); #ifdef GOOGLE_CUDA BM_FusedBatchNorm(64, 14, 14, 256, fp32, false, NHWC, gpu); @@ -375,6 +378,8 @@ BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NCHW, gpu); BM_FusedBatchNormGradResnetShapes(fp32, true, NHWC, cpu); BM_FusedBatchNormGradResnetShapes(fp32, false, NHWC, cpu); +BM_FusedBatchNormGradResnetShapes(bf16, true, NHWC, cpu); +BM_FusedBatchNormGradResnetShapes(bf16, false, NHWC, cpu); #ifdef GOOGLE_CUDA BM_FusedBatchNormGradResnetShapes(fp32, true, NHWC, gpu); diff --git a/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc b/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc index 5177eadfbe2015..ce2cb36f0d4e59 100644 --- a/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_tmp_bf16_ops.cc @@ -44,16 +44,6 @@ class RaiseBfloat16Error : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ RaiseBfloat16Error); \ - REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("U"), \ - RaiseBfloat16Error); \ - REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("U"), \ - RaiseBfloat16Error); \ REGISTER_KERNEL_BUILDER( \ Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ 
RaiseBfloat16Error); diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index c54969c144225c..4dd5202f6cdb2c 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -411,8 +411,6 @@ def _runtests(self, x_shape, is_training, gradient_test=False, factors = [1.0, 0.6] for dtype in [np.float16, np.float32, dtypes.bfloat16.as_numpy_dtype]: for use_gpu in use_gpu_vals: - if dtype == dtypes.bfloat16.as_numpy_dtype and not use_gpu: - continue for data_format in data_format_list: if data_format == 'NHWC' or data_format == 'NDHWC': scale_shape = x_shape[-1:] From 894af8ce883332a08c45283d1312087ac71de839 Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Fri, 28 Apr 2023 10:00:37 -0700 Subject: [PATCH 005/410] Fix to EagerContext API change. --- .../core/common_runtime/eager/mkl_eager_op_rewrite_test.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc index b01646b749890f..057fabdec12783 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc @@ -38,12 +38,13 @@ class EagerOpRewriteTest : public ::testing::Test { std::make_unique(DeviceFactory::NewDevice( "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); bool async = false; - tensorflow::Rendezvous* rendezvous = - new tensorflow::IntraProcessRendezvous(device_mgr.get()); + auto rendezvous = + tsl::core::RefCountPtr( + new tensorflow::IntraProcessRendezvous(device_mgr.get())); eager_ctx_ = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - async, device_mgr.get(), false, rendezvous, nullptr, nullptr, + async, device_mgr.get(), false, std::move(rendezvous), nullptr, nullptr, /*run_eager_op_as_function=*/true); EagerExecutor executor_(false); From 1cf60253d3142a8ad3d0e2fede3ec2cfd6c368ec Mon Sep 17 00:00:00 2001 From: shuw Date: Wed, 12 Apr 2023 21:12:48 -0700 Subject: [PATCH 006/410] Fix bias Add for high-rank tensor input --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index a0b0ff439f1bd3..c192ccf4a5284b 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -522,11 +522,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) -> // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd) // - if (Match(instr, - m::AddAnyOrder( - m::Bitcast(CublasLtMatmul(&existing_gemm).WithOneUser()) - .WithOneUser(), - m::Broadcast(&bias, m::Op()).WithOneUser()))) { + if (Match( + instr, + m::AddAnyOrder( + m::Bitcast(CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) + .WithOneUser(), + m::Broadcast(&bias, m::Op()).WithOneUser()))) { TF_ASSIGN_OR_RETURN( HloInstruction * new_add, MakeBinaryHlo(HloOpcode::kAdd, existing_gemm, @@ -557,7 +558,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // transformation, but it doesn't hurt anything. 
if (Match(instr, m::AddAnyOrder( - m::Bitcast(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser()) + m::Bitcast( + GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) .WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { HloInstruction *new_bitcast = @@ -573,8 +575,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (Match(instr, - m::AddAnyOrder(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser(), - m::Op(&bias).WithPredicate(is_not_broadcast)))) { + m::AddAnyOrder( + GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), + m::Op(&bias).WithPredicate(is_not_broadcast)))) { return FuseMatrixBiasAdd(instr, bias, existing_gemm); } From d8d2a7cc4b2f3372448c627c0f55574091af5a99 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 13 Apr 2023 18:29:02 +0000 Subject: [PATCH 007/410] Add unittest --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 7 +- .../service/gpu/tests/gemm_rewrite_test.cc | 71 ++++++++++++++++++- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index c192ccf4a5284b..ac6dc34ab6d29e 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -1106,7 +1106,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); - operands.insert(operands.begin() + 2, MaybeConstantFoldBias(bias)); + HloInstruction* broadcast_bias = MaybeConstantFoldBias(bias); + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { + operands.at(2) = broadcast_bias; + } else { + operands.insert(operands.begin() + 2, broadcast_bias); + } std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index dcfa03beaeba06..aa41cf32f4f989 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -4198,7 +4198,6 @@ class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { return; } EXPECT_TRUE(RunAndCompare(hlo_text, error_spec)); - // Most FP8 tests directly create a GemmRewriter and check the output. // Here, also run the entire HLO pass pipeline to ensure no other passes // interfere with GemmRewriter's pattern matching. 
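A note on why the rank-3 case exercised by the new test below reduces to a plain 2-D GEMM: the leading dimension of the [4,16,16] operand is folded into the row dimension by a bitcast, so the FP8 custom call computes a [64,16] x [16,32] product, the bias broadcast over the last dimension is viewed as a [64,32] matrix, and the [64,32] result is bitcast back to [4,16,32]. The snippet below is a plain C++ illustration of that flattening equivalence under made-up names; it is not XLA or test code.

#include <cstdio>
#include <vector>

// Illustration only: a [B, M, K] x [K, N] matmul plus a per-column bias gives
// exactly the same numbers as the flattened [B*M, K] x [K, N] matmul with the
// bias broadcast over rows, which is what lets the rewriter hand the rank-3
// operand to the 2-D cuBLASLt call and bitcast the result back to rank 3.
std::vector<float> MatmulBias(const std::vector<float>& a,     // [rows, k]
                              const std::vector<float>& b,     // [k, n]
                              const std::vector<float>& bias,  // [n]
                              int rows, int k, int n) {
  std::vector<float> out(rows * n, 0.0f);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < n; ++c) {
      float acc = bias[c];
      for (int i = 0; i < k; ++i) acc += a[r * k + i] * b[i * n + c];
      out[r * n + c] = acc;
    }
  }
  return out;
}

int main() {
  const int B = 4, M = 16, K = 16, N = 32;  // shapes used by the new test below
  std::vector<float> a(B * M * K, 1.0f), b(K * N, 2.0f), bias(N, 3.0f);
  std::vector<float> flat = MatmulBias(a, b, bias, B * M, K, N);
  // Element (batch, m, c) of the rank-3 view is flat[(batch * M + m) * N + c].
  std::printf("out[2][5][7] = %.1f\n", flat[(2 * M + 5) * N + 7]);  // 16*1*2 + 3 = 35.0
  return 0;
}
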
@@ -5259,6 +5258,76 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; +#endif + const char* hlo_text = R"( + HloModule test + + ENTRY test { + x = f8e4m3fn[4,16,16] parameter(0) + y = f8e4m3fn[16,32] parameter(1) + b = f32[32] parameter(2) + b_f16 = f16[32] convert(b) + b_bcast = f16[4,16,32] broadcast(b_f16), dimensions={2} + x_f16 = f16[4,16,16] convert(x) + y_f16 = f16[16,32] convert(y) + x_scale = f16[] parameter(3) + y_scale = f16[] parameter(4) + x_scale_bcast = f16[4,16,16] broadcast(x_scale), dimensions={} + y_scale_bcast = f16[16,32] broadcast(y_scale), dimensions={} + x_unscaled = f16[4,16,16] multiply(x_f16, x_scale_bcast) + x_unscaled_bitcast = f16[64,16] bitcast(x_unscaled) + y_unscaled = f16[16,32] multiply(y_f16, y_scale_bcast) + dot_a = f16[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a) + ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast) + } + +)"; + + CheckFp8IfOnHopper(hlo_text, ErrorSpec{0.1, 0.1}); + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( + +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,16,16], y: f8e4m3fn[16,32], b: f32[32], x_scale: f16[], y_scale: f16[]) -> f16[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2) +; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) +; CHECK-NEXT: [[B_BCAST:%[^ ]+]] = f16[4,16,32]{2,1,0} broadcast([[B_F16]]), dimensions={2} +; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f16[64,32]{1,0} bitcast([[B_BCAST]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) +; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) +; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2_CV]], [[P3_CV]], /*index=5*/[[C]], [[C]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]]) + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDVectorBiasThenReluActivationF8) { #if CUDA_VERSION < 12000 From 315d240858f48df25268b24d708a6120763dc03a Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 14 Apr 2023 08:34:16 -0700 Subject: [PATCH 008/410] Update unittest --- .../xla/service/gpu/tests/gemm_rewrite_test.cc | 16 ++++++++++++---- 1 file changed, 
12 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index aa41cf32f4f989..18c916a3f659ac 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -5264,7 +5264,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { #endif const char* hlo_text = R"( HloModule test - ENTRY test { x = f8e4m3fn[4,16,16] parameter(0) y = f8e4m3fn[16,32] parameter(1) @@ -5284,15 +5283,24 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { dot_a_bitcast = f16[4,16,32]{2,1,0} bitcast(dot_a) ROOT out = f16[4,16,32] add(dot_a_bitcast, b_bcast) } - )"; - CheckFp8IfOnHopper(hlo_text, ErrorSpec{0.1, 0.1}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::CustomCall({"__cublas$lt$matmul$f8"}) + .WithShape(F16, {64, 32})) + .WithShape(F16, {4, 16, 32}))); + RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(se::CudaComputeCapability{ se::CudaComputeCapability::HOPPER, 0}), R"( - ; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,16,16], y: f8e4m3fn[16,32], b: f32[32], x_scale: f16[], y_scale: f16[]) -> f16[4,16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) From 737953bc346b443917e711440e163b3c28b44673 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 9 Jun 2023 02:57:59 -0700 Subject: [PATCH 009/410] WIP: dup slice --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 205 +- .../service/gpu/tests/gemm_rewrite_test.cc | 4348 +---------------- 2 files changed, 342 insertions(+), 4211 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index ac6dc34ab6d29e..d1f3fb3cf14215 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -578,6 +578,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { m::AddAnyOrder( GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { + std::cout << "uuuuuuuuuuuuuuuuuuuuuuuuuuuu\n"; return FuseMatrixBiasAdd(instr, bias, existing_gemm); } @@ -724,20 +725,48 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; + }; + + // If slice is needed, won't pattern matched to FuseMatrixBias + bool slice_needed = pad_shape(instr->shape()).dimensions() != instr->shape().dimensions(); // Fuse the possible addition of a matrix bias here to enable the subsequent // fusion of the scaling and conversion of D into the Custom Call. Fusing // a matrix bias is only supported with CUDA 12 and above. 
HloInstruction *c = nullptr, *add = nullptr; - - if (instr->user_count() == 1 && - instr->users()[0]->opcode() == HloOpcode::kAdd) { - HloInstruction *bias = instr->users()[0]->mutable_operand( - !instr->users()[0]->operand_index(instr)); + bool has_matrix_bias = false; + bool is_high_rank_input = a->shape().rank() > 2 || b->shape().rank() > 2; + if (slice_needed && ((instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kAdd) || + (instr->user_count() == 1 && instr->users()[0]->user_count() == 1 && + instr->users()[0]->opcode()== HloOpcode::kBitcast && + instr->users()[0]->users()[0]->opcode() == HloOpcode::kAdd) ) ) { + std::cout << "yyyyyyyyyyyyyyyyyyyyyyyyyy\n"; + std::cout << a->shape().rank() <<" ;"<< b->shape().rank()<users()[0]->mutable_operand(!instr->users()[0]->operand_index(instr)); + } else { + bias = instr->users()[0]->users()[0]->mutable_operand( + !instr->users()[0]->users()[0]->operand_index(instr->users()[0])); + } if (bias->opcode() != HloOpcode::kBroadcast) { - c = bias; - gemm_backend_config.set_beta(1.0); - add = instr->users()[0]; + // c = bias; + // gemm_backend_config.set_beta(1.0); + // add = instr->users()[0]; + has_matrix_bias = true; } + std::cout << "hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh\n"; } // Each operand must have exactly one contracting and one non-contracting @@ -834,19 +863,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); } - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - }; - // Pad the non-batch dimensions of the operands to multiples of 16 as // required by cuBLASLt. auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { @@ -869,12 +885,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return; }; - pad_operand(a); - pad_operand(b); + // Possible padding of ouput shape. + Shape new_output_shape = instr->shape(); + if (!has_matrix_bias) { + pad_operand(a); + pad_operand(b); + new_output_shape = pad_shape(instr->shape()); + } if (c != nullptr) { + std::cout <<" pppppppaaaaaaa\n"; pad_operand(c); } - Shape new_output_shape = pad_shape(instr->shape()); + std::vector operands_list = { a, b, scales_f32[0], scales_f32[1], one, one}; @@ -895,15 +917,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Slice the result of the GEMM if the operands were padded. HloInstruction *slice = nullptr; - if (new_output_shape.dimensions() != instr->shape().dimensions()) { + if (new_output_shape.dimensions() != instr->shape().dimensions() && !has_matrix_bias) { std::vector start_indices(instr->shape().rank(), 0); std::vector strides(instr->shape().rank(), 1); slice = instr->AddInstruction(HloInstruction::CreateSlice( instr->shape(), new_custom_call, start_indices, instr->shape().dimensions(), strides)); + std::cout << "CCCCCCCCCCCCCCCCCCCCCCCCCC\n"; } - TF_RETURN_IF_ERROR( - ReplaceInstruction(add ? add : instr, slice ? slice : new_custom_call)); + +// TF_RETURN_IF_ERROR( +// ReplaceInstruction(add ? add : instr, slice ? 
slice : new_custom_call)); + std::cout <<"exiting CreateF8CustomCall!\n"; +TF_RETURN_IF_ERROR( + ReplaceInstruction(instr, slice?slice: new_custom_call)); return true; } @@ -1048,7 +1075,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, - const HloInstruction *gemm, + HloInstruction *gemm, HloInstruction *bitcast = nullptr) { TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())); @@ -1057,7 +1084,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (gemm->shape().element_type() == S32) { return OkStatus(); } - + std::cout << "11111111111111111111111111\n"; // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. bool can_overwrite_bias = [bias]() { @@ -1087,7 +1114,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return in_out_alias_config.ParameterHasAlias(bias->parameter_number(), /*param_index=*/{}); }(); - bool want_to_fuse_bias = IsCublasLtMatmul(*gemm) || can_overwrite_bias; + bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || IsCublasLtMatmul(*gemm) || can_overwrite_bias; auto config = gemm->backend_config().value(); @@ -1096,24 +1123,119 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bool supported_epilogue = ((config.epilogue() == GemmBackendConfig::DEFAULT) || (config.epilogue() == GemmBackendConfig::BIAS)); - + std::cout << "22222222222222222222\n"; + std::cout << config.beta() << "; " << want_to_fuse_bias << "; " + << gemm->user_count() << "; " << supported_epilogue << "; \n"; if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { return OkStatus(); } - + std::cout << "333333333333333333\n"; config.set_beta(1.0); std::vector operands(gemm->operands().begin(), gemm->operands().end()); - HloInstruction* broadcast_bias = MaybeConstantFoldBias(bias); + HloInstruction *broadcast_bias = MaybeConstantFoldBias(bias); + // if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { + // operands.at(2) = broadcast_bias; + // } else + + std::cout << "44444444444444444444\n"; + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - operands.at(2) = broadcast_bias; - } else { + absl::Span batch_dims = + config.dot_dimension_numbers().rhs_batch_dimensions(); + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; + }; + + // Pad the non-batch dimensions of the operands to multiples of 16 as + // required by cuBLASLt. 
+ auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { + PaddingConfig padding_config; + Shape padded_shape = pad_shape(x->shape()); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); + } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = + instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + x = instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return; + }; + HloInstruction *a = operands[0]; + HloInstruction *b = operands[1]; + pad_operand(a); + pad_operand(b); + operands[0] = a; + operands[1] = b; + pad_operand(broadcast_bias); operands.insert(operands.begin() + 2, broadcast_bias); - } - std::unique_ptr fused_op = + Shape new_output_shape = pad_shape(instr->shape()); + bool need_padding = + new_output_shape.dimensions() != instr->shape().dimensions(); + // HloInstruction *new_custom_call = + // gemm->parent()->AddInstruction(HloInstruction::CreateCustomCall( + // ShapeUtil::MakeShapeWithDenseLayout( + // instr->shape().element_type(), new_output_shape.dimensions(), + // instr->shape().layout().minor_to_major()), + // operands, kCublasLtMatmulF8CallTarget)); + std::unique_ptr new_custom_call = + // gemm->CloneWithNewShape(new_output_shape, operands); + gemm->CloneWithNewOperands(new_output_shape,operands); + TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(config)); + + std::cout << "555555555555555\n"; + + std::cout << "6666666666666666\n"; + std::cout << "need_padding"<<": "<< need_padding <<"\n"; + TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call.get())); + // Slice the result of the GEMM if the operands were padded. 
+ HloInstruction *slice = nullptr; + if (need_padding) { + std::vector start_indices(instr->shape().rank(), 0); + std::vector strides(instr->shape().rank(), 1); + slice = instr->AddInstruction(HloInstruction::CreateSlice( + instr->shape(), new_custom_call.get(), start_indices, + instr->shape().dimensions(), strides)); + + } + if (slice) { + new_custom_call = slice->CloneWithNewOperands( + slice->shape(), {slice->parent()->AddInstruction(std::move(new_custom_call))}); + } + std::cout << new_custom_call.get()->ToString() << std::endl; + std::cout << instr->ToString() << std::endl; + std::cout << instr->parent()->ToString() << std::endl; + std::cout << slice->ToString() << std::endl; + std::cout << slice->parent()->ToString() << std::endl; + std::cout << "7777777777777777777\n"; + // return ReplaceInstruction(instr, slice); + return ReplaceWithNewInstruction(instr,std::move(new_custom_call)); + + + + } else { + operands.insert(operands.begin() + 2, broadcast_bias); + + std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); TF_RETURN_IF_ERROR(fused_op->set_backend_config(config)); @@ -1145,8 +1267,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); } - + std::cout <<"at exiting fusematrixbiasadd:\n"; + std::cout << instr->ToString() << std::endl; + std::cout <<"----------------------------\n"; + std::cout << fused_op.get()->ToString() << std::endl; + std::cout <<"at exiting fusematrixbiasadd print done\n"; return ReplaceWithNewInstruction(instr, std::move(fused_op)); + } + + } StatusOr FuseVectorBiasAdd(HloInstruction *instr, diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 18c916a3f659ac..78c8fbd4b42bbb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -64,4122 +64,38 @@ class GemmRewriteTest : public GpuCodegenTest { bool tf32_state_; }; -TEST_F(GemmRewriteTest, CheckCustomCallTarget) { - const char* hlo_text = R"( -HloModule SimpleGemm - -ENTRY AddDotsFunc { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - DebugOptions debug_options = GetDebugOptionsForTest(); - if (debug_options.xla_gpu_enable_cublaslt()) { - MatchOptimizedHlo(hlo_text, - R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); - } else { - MatchOptimizedHlo(hlo_text, - R"(; CHECK: custom_call_target="__cublas$gemm")"); - } -} - -TEST_F(GemmRewriteTest, TestBatchedAutotuning) { - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() - << "There is no autotuning starting with the Nvidia Ampere generation"; - } - const char* hlo_text = R"( -HloModule ComplexDotMultipleNonContracting - -ENTRY %test { - %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) - %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) - ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} -} - -)"; - - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: selected_algorithm - )"); -} - -TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) { - const char* hlo_text = R"( -HloModule SimpleGemm - -ENTRY AddDotsFunc { - x = f32[128,128] parameter(0) - y = f32[128,128] parameter(1) - ROOT dot_a = f32[128,128] dot(x, y), 
lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - ErrorSpec error_spec = [&] { - DebugOptions debug_options = GetDebugOptionsForTest(); - if (debug_options.xla_gpu_enable_cublaslt()) { - return ErrorSpec{1e-3, 1e-3}; - } else { - return ErrorSpec{1e-3, 1e-3}; - } - }(); - - auto get_module = [&]() { - HloModuleConfig config; - DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_deterministic_ops(true); - config.set_debug_options(debug_options); - return ParseAndReturnVerifiedModule(hlo_text, config); - }; - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr optimized_module, - backend().compiler()->RunHloPasses( - *get_module(), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator())); - - StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(), - R"( -; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}" - )"); - TF_ASSERT_OK(filecheck_result.status()); - EXPECT_TRUE(filecheck_result.value()); - EXPECT_TRUE(RunAndCompare(*get_module(), error_spec)); -} - -TEST_F(GemmRewriteTest, BF16GemmCodeGen) { - const char* hlo_text = R"( -HloModule bf16codegendgemm - -ENTRY bf16gemm { - %parameter.1 = bf16[3]{0} parameter(0) - %parameter.2 = bf16[3]{0} parameter(1) - ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest} -} - )"; - - MatchOptimizedHlo(hlo_text, R"( -; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1) -; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]]) -; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0) -; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]]) -; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]]) -; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0) -; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]] -; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]]) - )"); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(GemmRewriteTest, BF16Transpose) { - const char* hlo_text = R"( -HloModule broadcast - -ENTRY broadcast { - p = bf16[9] parameter(0) - ROOT out = bf16[1,9] broadcast(p), dimensions={1} -} -)"; - - MatchOptimizedHlo(hlo_text, R"( -; CHECK: bf16[1,9]{1,0} bitcast -; CHECK: bf16[1,9]{1,0} copy -)"); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -// A test fixture class for tests which should have similar results with legacy -// cublas and cublasLt -class ParameterizedGemmRewriteTest - : public GemmRewriteTest, - public ::testing::WithParamInterface { - public: - ParameterizedGemmRewriteTest() { - const bool kUsingCublasLt = GetParam(); - replacements_[kCustomCallTargetPlaceholder] = - kUsingCublasLt ? 
"__cublas$lt$matmul" : "__cublas$gemm"; - } - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_cublaslt(GetParam()); - return debug_options; - } - void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern, - bool print_operand_shape = false) { - GemmRewriteTest::MatchOptimizedHlo( - hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape); - } - absl::string_view CustomCallTarget() { - return replacements_[kCustomCallTargetPlaceholder]; - } - - protected: - absl::flat_hash_map replacements_; - - private: - static constexpr const char* kCustomCallTargetPlaceholder{ - "<>"}; -}; - -TEST_P(ParameterizedGemmRewriteTest, Simple) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) { - const char* hlo_text = R"( -HloModule SimpleGemm - -ENTRY AddDotsFunc { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) { - const char* hlo_text = R"( -HloModule MultipleContractingCheckGemm - -ENTRY AddDotsFunc { - x = f32[3,4,2] parameter(0) - y = f32[3,4,5] parameter(1) - ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - 
MatchOptimizedHlo(hlo_text, - R"( -; CHECK-NOT: copy -; -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,4,2], y: f32[3,4,5]) -> f32[2,5] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1) -; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]]) -; CHECK-DAG: [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5]{1,0} custom-call([[BITCAST0]], [[BITCAST1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) { - const char* hlo_text = R"( -HloModule ArgTransposeFoldGemm - -ENTRY AddDotsFunc { - x = f32[3,2] parameter(0) - y = f32[3,4] parameter(1) - x_transposed = f32[2,3] transpose(x), dimensions={1, 0} - ROOT dot_a = f32[2,4] dot(x_transposed, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) { - const char* hlo_text = R"( -HloModule BatchedArgRowColTransposeFoldGemm - -ENTRY AddDotsFunc { - x = f32[5,3,2] parameter(0) - y = f32[5,3,4] parameter(1) - x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1} - ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: 
\"lhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) { - const char* hlo_text = R"( -HloModule BatchRowTransposeFoldCheck - -ENTRY AddDotsFunc { - x = f32[2,5,3] parameter(0) - y = f32[5,3,4] parameter(1) - x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2} - ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) { - const char* hlo_text = R"( -HloModule BatchFromMinorDimTransposeDoesntFold - -ENTRY AddDotsFunc { - x = f32[3,2,5] parameter(0) - y = f32[5,3,4] parameter(1) - x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0} - ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) -; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]]) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[FUSION]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, LargeBatch) { - const char* hlo_text = R"( -HloModule BatchedArgRowColTransposeFoldGemm - -ENTRY AddDotsFunc { - x = f32[20000,4,3,2] parameter(0) - y = f32[20000,4,3,4] parameter(1) - ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} -} - -)"; - - 
// Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that - // the custom_call_target is __cublas$gemm. - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[20000,4,3,2], y: f32[20000,4,3,4]) -> f32[20000,4,2,4] { -; CHECK: [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0) -; CHECK: [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]]) -; CHECK: [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1) -; CHECK: [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]]) -; CHECK: [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} custom-call([[BC0]], [[BC1]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK: }" -; CHECK: ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]]) -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) { - const char* hlo_text = R"( -HloModule InstrTransposeFoldGemm - -ENTRY AddDotsFunc { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] { -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} custom-call([[P1]], [[P0]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) { - const char* hlo_text = R"( -HloModule BatchedInstrLayoutCheck - -ENTRY AddDotsFunc { - x = f32[5,2,3] parameter(0) - y = f32[5,3,4] parameter(1) - dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} - ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,0,1} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; 
CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast([[GEMM]]) -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) { - const char* hlo_text = R"( -HloModule BatchedInstrLayoutBatchNotInMinorDim - -ENTRY AddDotsFunc { - x = f32[5,2,3] parameter(0) - y = f32[5,3,4] parameter(1) - dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} - ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]([[GEMM]]) -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) { - const char* hlo_text = R"( -HloModule AlphaSimpleRewrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) { - const char* hlo_text = R"( -HloModule ComplexAlphaSimpleRewrite - -ENTRY AddDotsFunc { - x = c64[2,2] 
parameter(0) - y = c64[2,2] parameter(1) - k = c64[] constant((3.0, 3.0)) - k_broadcast = c64[2, 2] broadcast(k), dimensions={} - dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = c64[2,2]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":3 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) { - const char* hlo_text = R"( -HloModule AlphaMultipleUsersNoRewrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - ROOT out = f32[2,2] add(dot_a_multiplied, dot_a) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: {{[^ ]+}} = f32[2,2]{1,0} custom-call({{[^,]+}}, {{[^)]+}}), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) { - const char* hlo_text = R"( -HloModule AlphaVectorNoRewrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - alpha = f32[2] constant({1, 2}) - alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1} - dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast) -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="<>", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; 
CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) { - const char* hlo_text = R"( -HloModule bf16gemm - -ENTRY bf16gemm { - %parameter.1 = bf16[12,4]{1,0} parameter(0) - %parameter.2 = bf16[4,8]{1,0} parameter(1) - ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: bf16[16,8]{1,0} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<>" - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: bf16[12,8]{1,0} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="<>" - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) { - const char* hlo_text = R"( -HloModule bf16gemm - -ENTRY bf16gemm { - %parameter.1 = bf16[3,3,4] parameter(0) - %parameter.2 = bf16[3,3,2] parameter(1) - ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest} -} - - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { - MatchOptimizedHlo(hlo_text, - R"( - ; CHECK: bf16[3,8,8]{2,1,0} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<>" - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( - ; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="<>" - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) { - const char* hlo_text = R"( -HloModule int8gemm - -ENTRY int8gemm { - %parameter.1 = s8[12,4]{1,0} parameter(0) - %parameter.2 = s8[4,8]{1,0} parameter(1) - ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} - - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) { - const char* hlo_text = R"( -HloModule int8gemm - -ENTRY int8gemm { - %parameter.1 = s8[12,4]{1,0} parameter(0) - %parameter.2 = s8[4,8]{1,0} parameter(1) - k = s32[] constant(2) - k_broadcast = s32[12,8] broadcast(k), dimensions={} - %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] 
%parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT dot_multiplied = s32[12,8] multiply(%dot.8, k_broadcast) -} - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} - - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) { - const char* hlo_text = R"( -HloModule int8gemm - -ENTRY int8gemm { - %parameter.1 = s8[12,4]{1,0} parameter(0) - %parameter.2 = s8[4,8]{1,0} parameter(1) - bias = s32[12,8] parameter(2) - %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = s32[12,8] add(%dot.8, bias) -} - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} - - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) { - const char* hlo_text = R"( -HloModule int8gemm - -ENTRY int8gemm { - %parameter.1 = s8[13,4]{1,0} parameter(0) - %parameter.2 = s8[4,9]{1,0} parameter(1) - ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[16,12]{1,0} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" - )", - /*print_operand_shape=*/true); - } else { - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: s32[13,9]{1,0} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} - - )", - /*print_operand_shape=*/true); - } -} - -TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) { - std::vector> - type_combinations = { - {"s8", "s32", true}, {"s8", "s8", true}, {"s32", "s32", true}, - {"bf16", "bf16", true}, {"f16", "f16", true}, {"f32", "f32", true}, - {"f64", "f64", true}, {"c64", "c64", true}, {"c128", "c128", true}, - }; - - if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - // For compute capabilities before Volta, we always do upcasting, so it - // would be impossible for this test to fail. That is why we only add these - // cases when the compute capability is at least Volta.
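-    // In other words, with that implicit upcast the mixed-type dots below would
-    // still run and match the reference, so the EXPECT_FALSE(RunAndCompare(...))
-    // expectations for these combinations are only meaningful on Volta and newer.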
- std::vector> - more_type_combinations = { - {"s8", "bf16", false}, {"s8", "f16", false}, - {"s8", "f32", false}, {"s8", "f64", false}, - {"s8", "c64", false}, {"s8", "c128", false}, - - {"s32", "f32", false}, {"s32", "f64", false}, - {"s32", "c64", false}, {"s32", "c128", false}, - - {"f16", "bf16", false}, {"f16", "f32", false}, - {"f16", "f64", false}, {"f16", "c64", false}, - {"f16", "c128", false}, - - {"bf16", "f16", false}, {"bf16", "f64", false}, - {"bf16", "c64", false}, {"bf16", "c128", false}, - - {"f32", "f64", false}, {"f32", "c64", false}, - {"f32", "c128", false}, - - {"f64", "c64", false}, {"f64", "c128", false}, - }; - type_combinations.insert(type_combinations.end(), - more_type_combinations.begin(), - more_type_combinations.end()); - } - - for (const auto& type_combination : type_combinations) { - absl::flat_hash_map replacements; - replacements["<>"] = std::get<0>(type_combination); - replacements["<>"] = std::get<1>(type_combination); - const char* hlo_template = R"( - HloModule type_combo - - ENTRY type_combo { - %parameter.1 = <>[4,4]{1,0} parameter(0) - %parameter.2 = <>[4,4]{1,0} parameter(1) - ROOT %dot = <>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} - } - )"; - const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements); - if (std::get<2>(type_combination)) { - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - } else { - EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - } - } -} - -TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - Arg_0.1 = bf16[4,3]{1,0} parameter(0) - Arg_1.2 = bf16[3,6]{1,0} parameter(1) - ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - // This is a type combination which is not supported by cublasLt, expect - // GemmRewriter to choose legacy cublas. - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::CustomCall({"__cublas$gemm"}))); -} - -TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - Arg_0.1 = c64[4,3]{1,0} parameter(0) - Arg_1.2 = c64[3,6]{1,0} parameter(1) - ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - // This is a type combination which is not supported by cublasLt, expect - // GemmRewriter to choose legacy cublas. 
- EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::CustomCall({"__cublas$gemm"}))); -} - -TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - Arg_0.1 = f16[4,3]{1,0} parameter(0) - Arg_1.2 = f16[3,6]{1,0} parameter(1) - ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::CustomCall({CustomCallTarget()}))); -} - -TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - Arg_0.1 = f16[4,3]{1,0} parameter(0) - Arg_1.2 = f16[3,6]{1,0} parameter(1) - ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - // This is a type combination which is not supported by cublasLt, expect - // GemmRewriter to choose legacy cublas. - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::CustomCall({"__cublas$gemm"}))); -} - -TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - Arg_0.1 = f32[4,3]{1,0} parameter(0) - Arg_1.2 = f32[3,6]{1,0} parameter(1) - ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - // This is a type combination which is not supported by cublasLt, expect - // GemmRewriter to choose legacy cublas. - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::CustomCall({"__cublas$gemm"}))); -} - -INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt, - ParameterizedGemmRewriteTest, ::testing::Bool()); - -// A test fixture class for tests which are specific to legacy cublas -class LegacyCublasGemmRewriteTest : public GemmRewriteTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_cublaslt(false); - return debug_options; - } -}; - -// Test that the alpha and beta fields of the GemmBackendConfig are updated. -// A bias must be present for the beta value to be set. -// In order to have a bias add fused, the bias term must be overwritable. -// We assume that we may not overwrite parameters of a computation. Hence, we -// use the third parameter to create a new value which can be overwritten and -// will be used as the bias. This negate(param_2) has no semantic use, it simply -// exists so that bias may be overwritten. 
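-// As a rough sketch, the rewrite exercised below folds HLO of the form
-//   out = add(multiply(dot(x, y), broadcast(alpha)), bias)
-// into a single __cublas$gemm custom call: alpha is carried in the
-// "alpha_real" field of the backend config, the bias add is expressed as
-// "beta":1, and the (overwritable) bias buffer is passed as the third operand
-// and aliased with the output.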
-TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) { - const char* hlo_text = R"( -HloModule NonZeroAlphaBeta - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - param_2 = f32[2,2] parameter(2) - bias = f32[2,2] negate(param_2) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - ROOT out = f32[2,2] add(dot_a_multiplied, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) { - const char* hlo_text = R"( -HloModule BiasMultipleUsersNoOverwrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] parameter(2) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - biased_out = f32[2,2] add(dot_a_multiplied, bias) - ROOT out = f32[2,2] add(biased_out, bias) -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) { - const char* hlo_text = R"( -HloModule BiasParameterNoOverwrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] parameter(2) - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,2] add(dot_a, bias) -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 
1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) { - const char* hlo_text = R"( -HloModule BiasTupleParameterOverwrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - param_2 = (f32[2,2], f32[3,3]) parameter(2) - bias = f32[2,2] get-tuple-element(param_2), index=0 - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,2] add(dot_a, bias) -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: (f32[2,2], f32[3,3])) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2) -; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0 -; CHECK-DAG: [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]]) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[BIAS_COPY]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) { - const char* hlo_text = R"( -HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) } - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] parameter(2) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - ROOT out = f32[2,2] add(dot_a_multiplied, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} 
parameter(2) -; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) { - const char* hlo_text = R"( -HloModule LargerBiasMultipleUsersNoRewrite - -ENTRY AddDotsFunc { - x = f32[1024,1024] parameter(0) - y = f32[1024,1024] parameter(1) - bias = f32[1024,1024] parameter(2) - dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - biased_out = f32[1024,1024] add(dot_a, bias) - ROOT out = f32[1024,1024] add(biased_out, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: f32[1024,1024]) -> f32[1024,1024] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -// In order to have a bias add fused, the bias term must be overwritable. -// We assume that we may not overwrite parameters of a computation. Hence, we -// use the third parameter to create a new value which can be overwritten and -// will be used as the bias. This negate(param_2) has no semantic use, it simply -// exists so that bias may be overwritten. 
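-// The bf16 variant below uses a looser error tolerance (2e-3) to account for
-// the reduced mantissa precision of bfloat16.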
-TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) { - const char* hlo_text = R"( -HloModule BF16GemmWithBias - -ENTRY BF16GemmWithBias { - x = bf16[8,8]{1,0} parameter(0) - y = bf16[8,8]{1,0} parameter(1) - dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - param_2 = bf16[8,8]{1,0} parameter(2) - bias = bf16[8,8]{1,0} negate(param_2) - ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias) -} - )"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], param_2: bf16[8,8]) -> bf16[8,8] { -; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) -; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) -; CHECK: ROOT [[GEMM:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -// In order to have a bias add fused, the bias term must be overwritable. -// We assume that we may not overwrite parameters of a computation. Hence, we -// use the third parameter to create a new value which can be overwritten and -// will be used as the bias. This negate(param_2) has no semantic use, it simply -// exists so that bias may be overwritten. 
-TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - param_2 = f32[2,4] parameter(2) - bias = f32[2,4] negate(param_2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,4] add(dot_a, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], param_2: f32[2,4]) -> f32[2,4] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK: ROOT [[GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], {{[^,)]+}}), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - w = f32[2,3] parameter(0) - x = f32[3,4] parameter(1) - first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0} - y = f32[2,3] parameter(2) - z = f32[3,4] parameter(3) - second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,4] add(second_dot, first_dot) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) -; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3) -; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[SECOND_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P2]], [[P3]], [[FIRST_GEMM]]), -; CHECK: custom_call_target="__cublas$gemm", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: 
\"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) { - const char* hlo_text = R"( -HloModule test -ENTRY test { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[4] parameter(2) - dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[4] add(f32[4] bitcast(dot), bias) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Bitcast( - m::CustomCall({"__cublas$gemm"}, m::Parameter(0), m::Parameter(1), - m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2}))) - .WithShape(F32, {4}))); -} - -// In order to have a bias add fused, the bias term must be overwritable. -// We assume that we may not overwrite parameters of a computation. Hence, we -// use the third parameter to create a new value which can be overwritten and -// will be used as the bias. This negate(param_2) has no semantic use, it simply -// exists so that bias may be overwritten. -TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) { - const char* hlo_text = R"( -HloModule test -ENTRY test { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0} - - dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - param_2 = f32[2,2] parameter(2) - bias1 = f32[2,2] negate(param_2) - sum1 = add(dot1, bias1) - - dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - sum2 = add(dot2, f32[2,2] reshape(bias)) - - dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - bias3 = f32[2,2] transpose(bias), dimensions={1,0} - sum3 = add(dot3, bias3) - - dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - sum4 = add(dot4, f32[2,2] bitcast(bias)) - - ROOT root = tuple(sum1, sum2, sum3, sum4) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - SCOPED_TRACE(module->ToString()); - EXPECT_TRUE(changed); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple( - m::CustomCall(m::Parameter(0), m::Parameter(1), - m::Negate(m::Parameter(2))), - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant())))); -} - -// A test fixture class for tests which are specific to cublasLt -class CublasLtGemmRewriteTest : public GemmRewriteTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_cublaslt(true); - return debug_options; - } -}; - -TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) { - const char* hlo_text = R"( -HloModule NonZeroAlphaBeta - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] parameter(2) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - 
dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - ROOT out = f32[2,2] add(dot_a_multiplied, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) { - const char* hlo_text = R"( -HloModule BiasMultipleUsersNoOverwrite - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] parameter(2) - k = f32[] constant(3.0) - k_broadcast = f32[2, 2] broadcast(k), dimensions={} - dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) - biased_out = f32[2,2] add(dot_a_multiplied, bias) - ROOT out = f32[2,2] add(biased_out, bias) -} -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) -; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[BIAS]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK-NOT: output_to_operand_aliasing -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) { - const char* hlo_text = R"( -HloModule LargerBiasMultipleUsersNoRewrite - -ENTRY AddDotsFunc { - x = f32[1024,1024] parameter(0) - y = f32[1024,1024] parameter(1) - bias = f32[1024,1024] parameter(2) - dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - biased_out = f32[1024,1024] add(dot_a, bias) - ROOT out = f32[1024,1024] add(biased_out, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: 
f32[1024,1024]) -> f32[1024,1024] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) -; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} custom-call([[P0]], [[P1]], [[BIAS]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]]) -)"); -} - -TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY BF16GemmWithBias { - x = bf16[8,8]{1,0} parameter(0) - y = bf16[8,8]{1,0} parameter(1) - dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - bias = bf16[8,8]{1,0} parameter(2) - ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias) -} - )"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] { -; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) -; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) -; CHECK-DAG: [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[GEMM:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, MatrixBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2,4] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,4] add(dot_a, z) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; 
CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - w = f32[2,3] parameter(0) - x = f32[3,4] parameter(1) - first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0} - y = f32[2,3] parameter(2) - z = f32[3,4] parameter(3) - second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,4] add(second_dot, first_dot) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) -; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3) -; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[SECOND_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P2]], [[P3]], [[FIRST_GEMM]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={1} - ROOT out = f32[2,4] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: 
\"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -)"); -} - -// Epilogue Fusion disabled when GEMM has multiple users. -TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[4,4] parameter(0) - y = f32[4,4] parameter(1) - z = f32[4] parameter(2) - c = f32[] constant(5) - dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[4,4] broadcast(z), dimensions={1} - add_a = f32[4,4] add(dot_a, z_bcast) - c_bcast = f32[4,4] broadcast(c), dimensions={} - dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4]) -> f32[4,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) -; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) -; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[4,4]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} custom-call([[MATMUL0]], [[C0_BCAST]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} custom-call([[FUSION]], [[MATMUL1]]), -; CHECK: 
custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3,4] parameter(0) - y = f32[4,5,6] parameter(1) - z = f32[3,5,6] parameter(2) - dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} - z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3} - ROOT out = f32[2,3,5,6] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[3,5,6]) -> f32[6,30] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={1,2,3} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} bitcast([[P0_BCAST]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) - )"); -} - -TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3,4] parameter(0) - y = f32[4,5,6] parameter(1) - z = f32[6] parameter(2) - dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} - z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3} - ROOT out = f32[2,3,5,6] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[6]) -> f32[6,30] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[6]{0} parameter(0) -; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={3} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} 
bitcast([[P0_BCAST]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={0} - add = f32[2,4] add(dot_a, z_bcast) - ROOT out = f32[4,2] transpose(add), dimensions={1,0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]]) -)"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[4,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[3] parameter(2) - dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]} - z_bcast = f32[2,3] broadcast(z), dimensions={1} - ROOT out = f32[2,3] add(slice_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[4,3], y: f32[3,4], z: f32[3]) -> f32[2,3] { -; CHECK-NEXT: [[P0:%[^ 
]+]] = f32[4,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[MATMUL]]), slice={[0:2], [0:3]} - )"); -} - -// Epilogue Fusion disabled when slice has multiple users. -TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2] parameter(2) - c = f32[] constant(5) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]} - z_bcast = f32[2,2] broadcast(z), dimensions={1} - add_a = f32[2,2] add(slice_a, z_bcast) - c_bcast = f32[2,2] broadcast(c), dimensions={} - dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[2], [[DUMMY1:[^ ]+]]: f32[2,4]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2]{0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,4]{1,0} parameter(1) -; CHECK-DAG: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[P1]]), slice={[0:2], [0:2]} -; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[P0]]), dimensions={1} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} add([[SLICE]], [[P0_BCAST]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[2,2] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) -; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,2]{1,0} fusion([[P2]], [[MATMUL0]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL0]]), slice={[0:2], [0:2]} -; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) -; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[C0]]), dimensions={} -; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} 
custom-call([[SLICE]], [[C0_BCAST]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[FUSION]], [[MATMUL1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] parameter(3) - ROOT out = f32[2,4] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2_BCAST]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - z2 = f32[2,4] parameter(3) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={1} - add0 = f32[2,4] add(dot_a, z_bcast) - ROOT add1 = f32[2,4] add(add0, z2) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4], z2: f32[2,4]) -> f32[2,4] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-DAG: [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-DAG: [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = 
f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -)"); -} - -TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[16,24] parameter(0) - y = bf16[24,32] parameter(1) - z = bf16[32] parameter(2) - dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[16,32] broadcast(z), dimensions={1} - ROOT out = bf16[16,32] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[16,24], y: bf16[24,32], z: bf16[32]) -> bf16[16,32] { -; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,32]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on " - "architectures with bf16 Tensor Cores."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[2,3] parameter(0) - y = bf16[3,4] parameter(1) - z = bf16[4] parameter(2) - dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[2,4] broadcast(z), dimensions={1} - ROOT out = bf16[2,4] add(dot_a, z_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[2,3], y: bf16[3,4], z: bf16[4]) -> bf16[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P0]], [[C0]]), padding=0_6x0_5 -; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P1]], [[C0]]), padding=0_5x0_4 -; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[4]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; 
CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[2,4]{1,0} slice([[MATMUL]]), slice={[0:2], [0:4]} - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - ROOT out = f32[2,4] maximum(dot_a, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3,4] parameter(0) - y = f32[4,5,6] parameter(1) - dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} - c = f32[] constant(0) - c_bcast = f32[2,3,5,6] broadcast(c), dimensions={} - ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = 
f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = f32[] constant(0) - c_bcast = f32[2,2] broadcast(c), dimensions={} - slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]} - ROOT out = f32[2,2] maximum(slice_a, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]} - )"); -} - -TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2,4] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - add = f32[2,4] add(dot_a, z) - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - ROOT out = f32[2,4] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[4,4] parameter(0) - y = f32[4,4] parameter(1) - z = f32[4,4] parameter(2) - dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - add = f32[4,4] add(dot_a, z) - c = f32[] constant(0) - c_bcast = f32[4,4] broadcast(c), dimensions={} - ROOT out = f32[4,4] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4,4]) -> f32[4,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) -; 
CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={1} - add = f32[2,4] add(dot_a, z_bcast) - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - ROOT out = f32[2,4] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3,4] parameter(0) - y = f32[4,5,6] parameter(1) - z = f32[3,5,6] parameter(2) - dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} - z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3} - add = f32[2,3,5,6] add(dot_a, z_bcast) - c = f32[] constant(0) - c_bcast = f32[2,3,5,6] broadcast(c), dimensions={} - ROOT out = f32[2,3,5,6] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[3,5,6]) -> f32[6,30] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={1,2,3} -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} bitcast([[P0_BCAST]]) -} - -; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = 
f32[4,5,6]{2,1,0} parameter(1) -; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) -; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[2] parameter(2) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={0} - add = f32[2,4] add(dot_a, z_bcast) - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - maximum = f32[2,4] maximum(add, c_bcast) - ROOT out = f32[4,2] transpose(maximum), dimensions={1,0} -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]]) - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z_vec = f32[4] parameter(2) - z_matrix = f32[2,4] parameter(3) - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z_vec), dimensions={1} - add0 = f32[2,4] add(dot_a, z_bcast) - add1 = f32[2,4] add(add0, z_matrix) - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - ROOT out = f32[2,4] maximum(add1, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z_vec: f32[4], z_matrix: f32[2,4]) -> f32[2,4] { -; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} 
parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P3]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - mul.0 = f32[2,4] multiply(dot, dot) - mul.1 = f32[2,4] multiply(dot, mul.0) - const.0 = f32[] constant(0.044715) - bcast.0 = f32[2,4] broadcast(const.0), dimensions={} - mul.2 = f32[2,4] multiply(mul.1, bcast.0) - add.0 = f32[2,4] add(dot, mul.2) - const.1 = f32[] constant(0.797884583) - bcast.1 = f32[2,4] broadcast(const.1), dimensions={} - mul.3 = f32[2,4] multiply(add.0, bcast.1) - tanh = f32[2,4] tanh(mul.3) - const.2 = f32[] constant(1) - bcast.2 = f32[2,4] broadcast(const.2), dimensions={} - add.2 = f32[2,4] add(tanh, bcast.2) - const.3 = f32[] constant(0.5) - bcast.3 = f32[2,4] broadcast(const.3), dimensions={} - mul.4 = f32[2,4] multiply(add.2, bcast.3) - ROOT out = f32[2,4] multiply(dot, mul.4) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"GELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) { - // Modify one constant slightly, so it should no longer pattern match. 
- const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - mul.0 = f32[2,4] multiply(dot, dot) - mul.1 = f32[2,4] multiply(dot, mul.0) - const.0 = f32[] constant(0.05) - bcast.0 = f32[2,4] broadcast(const.0), dimensions={} - mul.2 = f32[2,4] multiply(mul.1, bcast.0) - add.0 = f32[2,4] add(dot, mul.2) - const.1 = f32[] constant(0.797884583) - bcast.1 = f32[2,4] broadcast(const.1), dimensions={} - mul.3 = f32[2,4] multiply(add.0, bcast.1) - tanh = f32[2,4] tanh(mul.3) - const.2 = f32[] constant(1) - bcast.2 = f32[2,4] broadcast(const.2), dimensions={} - add.2 = f32[2,4] add(tanh, bcast.2) - const.3 = f32[] constant(0.5) - bcast.3 = f32[2,4] broadcast(const.3), dimensions={} - mul.4 = f32[2,4] multiply(add.2, bcast.3) - ROOT out = f32[2,4] multiply(dot, mul.4) -} - -)"; - - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-NOT: GELU - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={1} - add = f32[2,4] add(dot, z_bcast) - mul.0 = f32[2,4] multiply(add, add) - mul.1 = f32[2,4] multiply(add, mul.0) - const.0 = f32[] constant(0.044715) - bcast.0 = f32[2,4] broadcast(const.0), dimensions={} - mul.2 = f32[2,4] multiply(mul.1, bcast.0) - add.0 = f32[2,4] add(add, mul.2) - const.1 = f32[] constant(0.797884583) - bcast.1 = f32[2,4] broadcast(const.1), dimensions={} - mul.3 = f32[2,4] multiply(add.0, bcast.1) - tanh = f32[2,4] tanh(mul.3) - const.2 = f32[] constant(1) - bcast.2 = f32[2,4] broadcast(const.2), dimensions={} - add.2 = f32[2,4] add(tanh, bcast.2) - const.3 = f32[] constant(0.5) - bcast.3 = f32[2,4] broadcast(const.3), dimensions={} - mul.4 = f32[2,4] multiply(add.2, bcast.3) - ROOT out = f32[2,4] multiply(add, mul.4) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_GELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - mul.0 = f32[2,4] multiply(dot, dot) - mul.1 = f32[2,4] multiply(dot, mul.0) - const.0 = f32[] constant(0.044715) - bcast.0 = f32[2,4] broadcast(const.0), 
dimensions={} - mul.2 = f32[2,4] multiply(mul.1, bcast.0) - add.0 = f32[2,4] add(dot, mul.2) - const.1 = f32[] constant(0.797884583) - bcast.1 = f32[2,4] broadcast(const.1), dimensions={} - mul.3 = f32[2,4] multiply(add.0, bcast.1) - tanh = f32[2,4] tanh(mul.3) - const.2 = f32[] constant(1) - bcast.2 = f32[2,4] broadcast(const.2), dimensions={} - add.2 = f32[2,4] add(tanh, bcast.2) - const.3 = f32[] constant(0.5) - bcast.3 = f32[2,4] broadcast(const.3), dimensions={} - mul.4 = f32[2,4] multiply(add.2, bcast.3) - mul.5 = f32[2,4] multiply(dot, mul.4) - ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> (f32[2,4], f32[2,4]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}) custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"GELU_AUX\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f32[2,4] broadcast(z), dimensions={1} - add = f32[2,4] add(dot, z_bcast) - mul.0 = f32[2,4] multiply(add, add) - mul.1 = f32[2,4] multiply(add, mul.0) - const.0 = f32[] constant(0.044715) - bcast.0 = f32[2,4] broadcast(const.0), dimensions={} - mul.2 = f32[2,4] multiply(mul.1, bcast.0) - add.0 = f32[2,4] add(add, mul.2) - const.1 = f32[] constant(0.797884583) - bcast.1 = f32[2,4] broadcast(const.1), dimensions={} - mul.3 = f32[2,4] multiply(add.0, bcast.1) - tanh = f32[2,4] tanh(mul.3) - const.2 = f32[] constant(1) - bcast.2 = f32[2,4] broadcast(const.2), dimensions={} - add.2 = f32[2,4] add(tanh, bcast.2) - const.3 = f32[] constant(0.5) - bcast.3 = f32[2,4] broadcast(const.3), dimensions={} - mul.4 = f32[2,4] multiply(add.2, bcast.3) - mul.5 = f32[2,4] multiply(add, mul.4) - ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> (f32[2,4], f32[2,4]) { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}) custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; 
CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_GELU_AUX\" -; CHECK: }" - )"); -} - -// For F16, the sizes of all dimensions of the operands are required to be -// multiples of 8 to allow matrix bias fusion. -TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[8,16] parameter(0) - y = f16[16,8] parameter(1) - z = f16[8,8] parameter(2) - dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f16[8,8] add(dot_a, z) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" - )"); -} - -// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and -// newer architectures) so that the sizes of all dimensions are multiples of 8. 
-TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[8,16] parameter(0) - y = f16[16,8] parameter(1) - z = f16[8] parameter(2) - dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f16[8,8] broadcast(z), dimensions={1} - ROOT add = f16[8,8] add(dot_a, z_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - GTEST_SKIP() << "Padding of GEMM operands only implemented on " - "architectures with Tensor Cores."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[6,12] parameter(0) - y = f16[12,6] parameter(1) - z = f16[6] parameter(2) - dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f16[6,6] broadcast(z), dimensions={1} - ROOT add = f16[6,6] add(dot_a, z_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -; CHECK-NEXT: [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} - )"); -} - -// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and -// newer architectures) so that the sizes of all dimensions are multiples of 8. 
-TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[8,16] parameter(0) - y = f16[16,8] parameter(1) - dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = f16[] constant(0) - c_bcast = f16[8,8] broadcast(c), dimensions={} - ROOT out = f16[8,8] maximum(dot_a, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - GTEST_SKIP() << "Padding of GEMM operands only implemented on " - "architectures with Tensor Cores."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[6,12] parameter(0) - y = f16[12,6] parameter(1) - dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = f16[] constant(0) - c_bcast = f16[6,6] broadcast(c), dimensions={} - ROOT out = f16[6,6] maximum(dot_a, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} - )"); -} - -TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[8,16] parameter(0) - y = f16[16,8] parameter(1) - z = f16[8,8] parameter(2) - dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - add = f16[8,8] add(dot_a, z) - c = 
f16[] constant(0) - c_bcast = f16[8,8] broadcast(c), dimensions={} - ROOT out = f16[8,8] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and -// newer architectures) so that the sizes of all dimensions are multiples of 8. -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[8,16] parameter(0) - y = f16[16,8] parameter(1) - z = f16[8] parameter(2) - dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f16[8,8] broadcast(z), dimensions={1} - add = f16[8,8] add(dot_a, z_bcast) - c = f16[] constant(0) - c_bcast = f16[8,8] broadcast(c), dimensions={} - ROOT out = f16[8,8] maximum(add, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { - GTEST_SKIP() << "Padding of GEMM operands only implemented on " - "architectures with Tensor Cores."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f16[6,12] parameter(0) - y = f16[12,6] parameter(1) - z = f16[6] parameter(2) - dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f16[6,6] broadcast(z), dimensions={1} - add = f16[6,6] add(dot_a, z_bcast) - c = f16[] constant(0) - c_bcast = f16[6,6] broadcast(c), dimensions={} - ROOT out = f16[6,6] maximum(add, 
c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) -; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) -; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) -; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -// For bfloat16, the sizes of all dimensions of the operands are required to be -// multiples of 8 to allow matrix bias fusion. -TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[8,16] parameter(0) - y = bf16[16,8] parameter(1) - z = bf16[8,8] parameter(2) - dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = bf16[8,8] add(dot_a, z) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8,8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[8,16] parameter(0) - y = bf16[16,8] parameter(1) - bias = bf16[2,4,8] parameter(2) - dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - bitcast = bf16[2,4,8] bitcast(dot) - ROOT out = bf16[2,4,8] add(bitcast, bias) -} - -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - EXPECT_TRUE(changed); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch( - m::Bitcast(m::CustomCall( - 
{"__cublas$lt$matmul"}, - m::Parameter(0).WithShape(BF16, {8, 16}), - m::Parameter(1).WithShape(BF16, {16, 8}), - m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8}))) - .WithShape(BF16, {2, 4, 8}))); -} - -// For bfloat16, the operands are padded if necessary on Ampere and newer -// architectures so that the sizes of all dimensions are multiples of 8. -TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[8,16] parameter(0) - y = bf16[16,8] parameter(1) - z = bf16[8] parameter(2) - dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[8,8] broadcast(z), dimensions={1} - ROOT add = bf16[8,8] add(dot_a, z_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " - "Ampere and newer architectures."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[6,12] parameter(0) - y = bf16[12,6] parameter(1) - z = bf16[6] parameter(2) - dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[6,6] broadcast(z), dimensions={1} - ROOT add = bf16[6,6] add(dot_a, z_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: 
\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" -; CHECK-NEXT: [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} - )"); -} - -// For bfloat16, the operands are padded if necessary on Ampere and newer -// architectures so that the sizes of all dimensions are multiples of 8. -TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[8,16] parameter(0) - y = bf16[16,8] parameter(1) - dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = bf16[] constant(0) - c_bcast = bf16[8,8] broadcast(c), dimensions={} - ROOT out = bf16[8,8] maximum(dot_a, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " - "Ampere and newer architectures."; - } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[6,12] parameter(0) - y = bf16[12,6] parameter(1) - dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - c = bf16[] constant(0) - c_bcast = bf16[6,6] broadcast(c), dimensions={} - ROOT out = bf16[6,6] maximum(dot_a, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = 
bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} - )"); -} - -// For bfloat16, the operands are padded if necessary on Ampere and newer -// architectures so that the sizes of all dimensions are multiples of 8. -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[8,16] parameter(0) - y = bf16[16,8] parameter(1) - z = bf16[8] parameter(2) - dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[8,8] broadcast(z), dimensions={1} - add = bf16[8,8] add(dot_a, z_bcast) - c = bf16[] constant(0) - c_bcast = bf16[8,8] broadcast(c), dimensions={} - ROOT out = bf16[8,8] maximum(add, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) { - if (!GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { - GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " - "Ampere and newer architectures."; +// A test fixture class for tests which should have similar results with legacy +// cublas and cublasLt +class ParameterizedGemmRewriteTest + : public GemmRewriteTest, + public ::testing::WithParamInterface { + public: + ParameterizedGemmRewriteTest() { + const bool kUsingCublasLt = GetParam(); + replacements_[kCustomCallTargetPlaceholder] = + kUsingCublasLt ? 
"__cublas$lt$matmul" : "__cublas$gemm"; + } + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cublaslt(GetParam()); + return debug_options; + } + void MatchOptimizedHlo(absl::string_view hlo, const absl::string_view pattern, + bool print_operand_shape = false) { + GemmRewriteTest::MatchOptimizedHlo( + hlo, absl::StrReplaceAll(pattern, replacements_), print_operand_shape); + } + absl::string_view CustomCallTarget() { + return replacements_[kCustomCallTargetPlaceholder]; } - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = bf16[6,12] parameter(0) - y = bf16[12,6] parameter(1) - z = bf16[6] parameter(2) - dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = bf16[6,6] broadcast(z), dimensions={1} - add = bf16[6,6] add(dot_a, z_bcast) - c = bf16[] constant(0) - c_bcast = bf16[6,6] broadcast(c), dimensions={} - ROOT out = bf16[6,6] maximum(add, c_bcast) -} - -)"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { -; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) -; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) -; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 -; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) -; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 -; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) -; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} - )"); -} - -TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f64[2,3] parameter(0) - y = f64[3,4] parameter(1) - z = f64[4] parameter(2) - dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - z_bcast = f64[2,4] broadcast(z), dimensions={1} - add = f64[2,4] add(dot_a, z_bcast) - c = f64[] constant(0) - c_bcast = f64[2,4] broadcast(c), dimensions={} - ROOT out = f64[2,4] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f64[2,3], y: f64[3,4], z: f64[4]) -> f64[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f64[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f64[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: 
\"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) { - const char* hlo_text = R"( -HloModule test - -ENTRY test { - x = f32[2,3] parameter(0) - y = f32[3,4] parameter(1) - z = f32[4] parameter(2) - k = f32[] constant(3.0) - k_bcast = f32[2,4] broadcast(k), dimensions={} - dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast) - z_bcast = f32[2,4] broadcast(z), dimensions={1} - add = f32[2,4] add(dot_a_multiplied, z_bcast) - c = f32[] constant(0) - c_bcast = f32[2,4] broadcast(c), dimensions={} - ROOT out = f32[2,4] maximum(add, c_bcast) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) -; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), -; CHECK: custom_call_target="__cublas$lt$matmul", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":3 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - -TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) { - const char* hlo_text = R"( -HloModule test -ENTRY test { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0} - - dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - bias1 = f32[2,2] parameter(2) - sum1 = add(dot1, bias1) - - dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - sum2 = add(dot2, f32[2,2] reshape(bias)) - - dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - bias3 = f32[2,2] transpose(bias), dimensions={1,0} - sum3 = add(dot3, bias3) - - dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - sum4 = add(dot4, f32[2,2] bitcast(bias)) - - ROOT root = tuple(sum1, sum2, sum3, sum4) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text)); - GemmRewriter pass(GetCudaComputeCapability()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); - SCOPED_TRACE(module->ToString()); - EXPECT_TRUE(changed); - - EXPECT_THAT( - module->entry_computation()->root_instruction(), - GmockMatch(m::Tuple( - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()), - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), - m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), - m::CustomCall(m::Parameter(0), 
m::Parameter(1), m::Constant())))); -} - -TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) { - const char* hlo_text = R"( -HloModule multiple_maximum_users - -relu { - Arg_0 = f32[3,896,54]{2,1,0} parameter(0) - constant = f32[] constant(0) - broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={} - ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast) -} -ENTRY main { - constant = f32[] constant(1) - broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={} - Arg_2 = f32[1024,54]{1,0} parameter(2) - dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0} - Arg_1 = f32[54]{0} parameter(1) - broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2} - add = f32[3,896,54]{2,1,0} add(dot, broadcast_2) - call = f32[3,896,54]{2,1,0} call(add), to_apply=relu - Arg_0 = f32[1]{0} parameter(0) - reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0) - broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2} - reshape_2 = f32[] reshape(broadcast_3) - broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={} - multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4) - ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call) -} -)"; + protected: + absl::flat_hash_map replacements_; - // TODO(cjfj): Why do we need to relax the error constraint here?! - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4})); - MatchOptimizedHlo(hlo_text, - R"( -; CHECK: custom_call_target="__cublas$lt$matmul", - )"); -} + private: + static constexpr const char* kCustomCallTargetPlaceholder{ + "<>"}; +}; class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { public: @@ -5336,6 +1252,152 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; +#endif + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[4,16,16] parameter(0) + y = f8e4m3fn[16,32] parameter(1) + b = f32[4,16,32] parameter(2) + x_f32 = f32[4,16,16] convert(x) + y_f32 = f32[16,32] convert(y) + x_scale = f32[] parameter(3) + y_scale = f32[] parameter(4) + x_scale_bcast = f32[4,16,16] broadcast(x_scale), dimensions={} + y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={} + x_unscaled = f32[4,16,16] multiply(x_f32, x_scale_bcast) + x_unscaled_bitcast = f32[64,16] bitcast(x_unscaled) + y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast) + dot_a = f32[64,32] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_bitcast = f32[4,16,32]{2,1,0} bitcast(dot_a) + ROOT out = f32[4,16,32] add(dot_a_bitcast, b) + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::CustomCall({"__cublas$lt$matmul$f8"}) + .WithShape(F32, {64, 32})) + .WithShape(F32, {4, 16, 32}))); + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,16,16], y: f8e4m3fn[16,32], b: f32[4,16,32], x_scale: f32[], 
y_scale: f32[]) -> f32[4,16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,16,16]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2) +; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]]) + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, + Rank3ScaledABUnscaledDMatrixBiasPaddedF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; +#endif + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[4,15,15] parameter(0) + y = f8e4m3fn[15,31] parameter(1) + b = f32[4,15,31] parameter(2) + x_f32 = f32[4,15,15] convert(x) + y_f32 = f32[15,31] convert(y) + x_scale = f32[] parameter(3) + y_scale = f32[] parameter(4) + x_scale_bcast = f32[4,15,15] broadcast(x_scale), dimensions={} + y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={} + x_unscaled = f32[4,15,15] multiply(x_f32, x_scale_bcast) + x_unscaled_bitcast = f32[60,15] bitcast(x_unscaled) + y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast) + dot_a = f32[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_bitcast = f32[4,15,31]{2,1,0} bitcast(dot_a) + ROOT out = f32[4,15,31] add(dot_a_bitcast, b) + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::CustomCall({"__cublas$lt$matmul$f8"}) + .WithShape(F32, {64, 32})) + .WithShape(F32, {4, 16, 32}))); + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,15,15], y: f8e4m3fn[15,31], b: f32[4,15,31], x_scale: f32[], y_scale: f32[]) -> f32[4,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[60,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} 
parameter(2) +; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]]) +; CHECK: %slice = f32[60,31]{1,0} slice(%cublas-gemm.2), slice={[0:60], [0:31]} + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDVectorBiasThenReluActivationF8) { #if CUDA_VERSION < 12000 @@ -5882,66 +1944,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt, ParameterizedFp8GemmRewriteTest, ::testing::Bool()); -TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) { - const char* hlo = R"( - -HloModule module - -ENTRY main.10 { - Arg_0.1 = f16[384,128]{1,0} parameter(0) - Arg_1.2 = f16[128,256]{1,0} parameter(1) - dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} - Arg_2.3 = f16[256]{0} parameter(2) - reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3) - broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1} - reshape.7 = f16[256]{0} reshape(broadcast.6) - broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1} - ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8) -})"; - - MatchOptimizedHlo(hlo, R"( -// CHECK: \"beta\":0 - )"); -} - -class GemmRewriteAllocationTest : public GpuCodegenTest { - public: - void CheckNumberOfAllocations(const std::string& hlo, - int expected_number_of_allocations) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(hlo)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - backend().compiler()->RunBackend( - std::move(optimized_module), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator())); - GpuExecutable* gpu_executable = - static_cast(executable.get()); - absl::Span allocations = - gpu_executable->GetAllocations(); - CHECK_EQ(allocations.size(), expected_number_of_allocations); - } -}; - -TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) { - const char* hlo_text = R"( -HloModule SharedBufferAssignment - -ENTRY AddDotsFunc { - x = f32[2,2] parameter(0) - y = f32[2,2] parameter(1) - bias = f32[2,2] add(x, y) - dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,2] add(dot, bias) -} - -)"; - - // Bias should be fused into the multiplication. 
- CheckNumberOfAllocations(hlo_text, 3); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - } // namespace } // namespace gpu } // namespace xla From 42131773ff0aa19c5a5618b4a08c58cceb5fdc1a Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 10 Jun 2023 17:25:52 -0700 Subject: [PATCH 010/410] Rank3 all pass --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 356 ++++++++++-------- .../service/gpu/tests/gemm_rewrite_test.cc | 150 ++++++-- 2 files changed, 314 insertions(+), 192 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index d1f3fb3cf14215..52f8f808687900 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -328,6 +328,12 @@ auto OptionalConvert(HloInstruction **optional_convert, Pattern pattern) { std::move(pattern)); } +template +auto OptionalBitcast(HloInstruction **optional_bitcast, Pattern pattern) { + return m::AnyOf(m::Bitcast(optional_bitcast, pattern), + std::move(pattern)); +} + // The rewriting proceeds in a bottom-up way: // // (kDot A B) is rewritten into a (kCustomCall:gemm A B) @@ -494,28 +500,34 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } Status HandleAdd(HloInstruction *instr) override { - HloInstruction *bias, *existing_gemm; + HloInstruction *bias, *existing_gemm(nullptr); HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; + HloInstruction *optional_bitcast = nullptr; + VLOG(1)<<"Yes, in HandleAdd!\n"; + VLOG(1)<<"shuw:" << instr->GetModule()->ToString(); // Attempt to elide broadcast and fuse addition of a vector bias into GEMM, // including when slicing is applied to the result. if (Match(instr, m::AddAnyOrder( - OptionalSlice( + OptionalBitcast(&optional_bitcast,OptionalSlice( &optional_slice, CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) - .WithOneUser(), + .WithOneUser()).WithOneUser(), m::Broadcast(&bias, OptionalConvert(&optional_convert, m::Op()))))) { + VLOG(1) << "Yes, in HandleAdd VECTOR!\n"; TF_ASSIGN_OR_RETURN(bool was_fused, FuseVectorBiasAdd(instr, bias, existing_gemm, - optional_slice, optional_convert)); + optional_slice, optional_convert, + optional_bitcast)); if (was_fused) { return OkStatus(); } } - + // return OkStatus(); + VLOG(1)<<"HandleAdd:Yes, 11111111111111!\n"; // Attempt to elide broadcast and fuse addition of a vector bias into // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd. // add(bitcast(gemm(a, b)), broadcast(bias)) -> @@ -538,7 +550,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below. instr = new_add; } - + VLOG(1)<<"Yes, 22222222222222!\n"; // Do not fuse broadcast unless we can fuse its input, as it will cause // broadcast materialization. auto is_not_broadcast = [](const HloInstruction *instr) { @@ -573,15 +585,40 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below transforming new_add. instr = new_add; } - + VLOG(1)<<"Yes, 333333333333! 
if gemm null\n" << (existing_gemm == nullptr); + VLOG(1)<<"Yes, 333333333333!\n" << instr->ToShortString(); if (Match(instr, m::AddAnyOrder( GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - std::cout << "uuuuuuuuuuuuuuuuuuuuuuuuuuuu\n"; + VLOG(1) << "Yes, heads to FuseMatrixBiasAdd!\n"; return FuseMatrixBiasAdd(instr, bias, existing_gemm); } - + VLOG(1) << "Herer cccccccccccccc\n"; + // if (bias) { + // VLOG(1) << "bias not emp"; + // VLOG(1) << bias->ToString(); + // } + HloInstruction *optional_bitcast_matrix = nullptr; + HloInstruction* optional_slice_matrix = nullptr; + if (Match(instr, + m::AddAnyOrder( + OptionalBitcast( + &optional_bitcast_matrix, + OptionalSlice( + &optional_slice_matrix, + CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())) + .WithOneUser(), + m::Op(&bias).WithPredicate(is_not_broadcast)))) { + VLOG(1) << "Herer uuuuuuuuuuuuuuuuuuuuuuuuuuuu\n"; + if (bias) { + VLOG(1) << "bias not emp"; + VLOG(1) << bias->ToString(); + VLOG(1) << bias->users()[0]->ToString(); + } + return FuseMatrixBiasAdd(instr, bias, existing_gemm, + optional_bitcast_matrix, optional_slice_matrix); + } return OkStatus(); } @@ -725,48 +762,21 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - }; - - // If slice is needed, won't pattern matched to FuseMatrixBias - bool slice_needed = pad_shape(instr->shape()).dimensions() != instr->shape().dimensions(); // Fuse the possible addition of a matrix bias here to enable the subsequent // fusion of the scaling and conversion of D into the Custom Call. Fusing // a matrix bias is only supported with CUDA 12 and above. 
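+    // The block below fuses a possible matrix bias: if the dot's only user is
+    // an add whose other operand is not a broadcast, that operand becomes the
+    // matrix bias c, beta is set to 1.0, and the add is recorded so it can be
+    // replaced together with the dot by the custom call.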
HloInstruction *c = nullptr, *add = nullptr; - bool has_matrix_bias = false; - bool is_high_rank_input = a->shape().rank() > 2 || b->shape().rank() > 2; - if (slice_needed && ((instr->user_count() == 1 && - instr->users()[0]->opcode() == HloOpcode::kAdd) || - (instr->user_count() == 1 && instr->users()[0]->user_count() == 1 && - instr->users()[0]->opcode()== HloOpcode::kBitcast && - instr->users()[0]->users()[0]->opcode() == HloOpcode::kAdd) ) ) { - std::cout << "yyyyyyyyyyyyyyyyyyyyyyyyyy\n"; - std::cout << a->shape().rank() <<" ;"<< b->shape().rank()<users()[0]->mutable_operand(!instr->users()[0]->operand_index(instr)); - } else { - bias = instr->users()[0]->users()[0]->mutable_operand( - !instr->users()[0]->users()[0]->operand_index(instr->users()[0])); - } + + if (instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kAdd) { + HloInstruction *bias = instr->users()[0]->mutable_operand( + !instr->users()[0]->operand_index(instr)); if (bias->opcode() != HloOpcode::kBroadcast) { - // c = bias; - // gemm_backend_config.set_beta(1.0); - // add = instr->users()[0]; - has_matrix_bias = true; + VLOG(1) <<"WWWWWWWWWWWWWWWWWWWWWWWW"; + c = bias; + gemm_backend_config.set_beta(1.0); + add = instr->users()[0]; } - std::cout << "hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh\n"; } // Each operand must have exactly one contracting and one non-contracting @@ -863,6 +873,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); } + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; + }; + // Pad the non-batch dimensions of the operands to multiples of 16 as // required by cuBLASLt. auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { @@ -885,18 +908,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return; }; - // Possible padding of ouput shape. - Shape new_output_shape = instr->shape(); - if (!has_matrix_bias) { - pad_operand(a); - pad_operand(b); - new_output_shape = pad_shape(instr->shape()); - } + pad_operand(a); + pad_operand(b); if (c != nullptr) { - std::cout <<" pppppppaaaaaaa\n"; pad_operand(c); } - + Shape new_output_shape = pad_shape(instr->shape()); std::vector operands_list = { a, b, scales_f32[0], scales_f32[1], one, one}; @@ -917,20 +934,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Slice the result of the GEMM if the operands were padded. HloInstruction *slice = nullptr; - if (new_output_shape.dimensions() != instr->shape().dimensions() && !has_matrix_bias) { + if (new_output_shape.dimensions() != instr->shape().dimensions()) { std::vector start_indices(instr->shape().rank(), 0); std::vector strides(instr->shape().rank(), 1); slice = instr->AddInstruction(HloInstruction::CreateSlice( instr->shape(), new_custom_call, start_indices, instr->shape().dimensions(), strides)); - std::cout << "CCCCCCCCCCCCCCCCCCCCCCCCCC\n"; } - -// TF_RETURN_IF_ERROR( -// ReplaceInstruction(add ? add : instr, slice ? slice : new_custom_call)); - std::cout <<"exiting CreateF8CustomCall!\n"; -TF_RETURN_IF_ERROR( - ReplaceInstruction(instr, slice?slice: new_custom_call)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(add ? add : instr, slice ? 
slice : new_custom_call)); return true; } @@ -1075,16 +1087,18 @@ TF_RETURN_IF_ERROR( } Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, - HloInstruction *gemm, - HloInstruction *bitcast = nullptr) { + const HloInstruction *gemm, + HloInstruction *bitcast = nullptr, + HloInstruction *slice = nullptr) { + // return OkStatus(); TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())); - + VLOG(1) << "FuseMatrixBiasAdd:1111111111111111111\n"; // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. if (gemm->shape().element_type() == S32) { return OkStatus(); } - std::cout << "11111111111111111111111111\n"; + VLOG(1) << "FuseMatrixBiasAdd:22222222222222\n"; // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. bool can_overwrite_bias = [bias]() { @@ -1114,36 +1128,33 @@ TF_RETURN_IF_ERROR( return in_out_alias_config.ParameterHasAlias(bias->parameter_number(), /*param_index=*/{}); }(); - bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || IsCublasLtMatmul(*gemm) || can_overwrite_bias; + bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || IsCublasLtMatmul(*gemm) || can_overwrite_bias; auto config = gemm->backend_config().value(); - + VLOG(1) << "FuseMatrixBiasAdd:3333333333333333\n"; // It is possible to fuse into a cublasLt matmul that already has a vector // bias, but no other epilogue will commute with the matrix bias add. bool supported_epilogue = ((config.epilogue() == GemmBackendConfig::DEFAULT) || (config.epilogue() == GemmBackendConfig::BIAS)); - std::cout << "22222222222222222222\n"; - std::cout << config.beta() << "; " << want_to_fuse_bias << "; " + VLOG(1) << config.beta() << "; " << want_to_fuse_bias << "; " << gemm->user_count() << "; " << supported_epilogue << "; \n"; if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { return OkStatus(); } - std::cout << "333333333333333333\n"; + config.set_beta(1.0); std::vector operands(gemm->operands().begin(), gemm->operands().end()); - HloInstruction *broadcast_bias = MaybeConstantFoldBias(bias); - // if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - // operands.at(2) = broadcast_bias; - // } else - - std::cout << "44444444444444444444\n"; - - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - absl::Span batch_dims = + HloInstruction* broadcast_bias = MaybeConstantFoldBias(bias); + if (slice) { + broadcast_bias = instr->AddInstruction(HloInstruction::CreateBitcast(slice->shape(),broadcast_bias)); + } + VLOG(1) << broadcast_bias->ToString(); + if(bitcast) VLOG(1) << bitcast->ToString(); + absl::Span batch_dims = config.dot_dimension_numbers().rhs_batch_dimensions(); // Get the padded shape. 
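+    // cuBLASLt FP8 matmuls require the non-batch dimensions to be multiples of
+    // 16, so the bias is padded the same way as the operands. For example, in
+    // the Rank3ScaledABUnscaledDMatrixBiasPaddedF8 test a f32[45,31] bias is
+    // padded to f32[48,32], and the padded GEMM result is sliced back to
+    // f32[45,31] before being bitcast to f32[3,15,31].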
auto pad_shape = [&batch_dims](const Shape old_shape) { @@ -1179,63 +1190,15 @@ TF_RETURN_IF_ERROR( } return; }; - HloInstruction *a = operands[0]; - HloInstruction *b = operands[1]; - pad_operand(a); - pad_operand(b); - operands[0] = a; - operands[1] = b; pad_operand(broadcast_bias); + VLOG(1) << "FuseMatrixBiasAdd:44444444444444444444\n"; + // if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { + // operands.at(2) = broadcast_bias; + // } else { operands.insert(operands.begin() + 2, broadcast_bias); + // } - Shape new_output_shape = pad_shape(instr->shape()); - bool need_padding = - new_output_shape.dimensions() != instr->shape().dimensions(); - // HloInstruction *new_custom_call = - // gemm->parent()->AddInstruction(HloInstruction::CreateCustomCall( - // ShapeUtil::MakeShapeWithDenseLayout( - // instr->shape().element_type(), new_output_shape.dimensions(), - // instr->shape().layout().minor_to_major()), - // operands, kCublasLtMatmulF8CallTarget)); - std::unique_ptr new_custom_call = - // gemm->CloneWithNewShape(new_output_shape, operands); - gemm->CloneWithNewOperands(new_output_shape,operands); - TF_RETURN_IF_ERROR(new_custom_call->set_backend_config(config)); - - std::cout << "555555555555555\n"; - - std::cout << "6666666666666666\n"; - std::cout << "need_padding"<<": "<< need_padding <<"\n"; - TF_RETURN_IF_ERROR(SetName(instr->GetModule(), new_custom_call.get())); - // Slice the result of the GEMM if the operands were padded. - HloInstruction *slice = nullptr; - if (need_padding) { - std::vector start_indices(instr->shape().rank(), 0); - std::vector strides(instr->shape().rank(), 1); - slice = instr->AddInstruction(HloInstruction::CreateSlice( - instr->shape(), new_custom_call.get(), start_indices, - instr->shape().dimensions(), strides)); - - } - if (slice) { - new_custom_call = slice->CloneWithNewOperands( - slice->shape(), {slice->parent()->AddInstruction(std::move(new_custom_call))}); - } - std::cout << new_custom_call.get()->ToString() << std::endl; - std::cout << instr->ToString() << std::endl; - std::cout << instr->parent()->ToString() << std::endl; - std::cout << slice->ToString() << std::endl; - std::cout << slice->parent()->ToString() << std::endl; - std::cout << "7777777777777777777\n"; - // return ReplaceInstruction(instr, slice); - return ReplaceWithNewInstruction(instr,std::move(new_custom_call)); - - - - } else { - operands.insert(operands.begin() + 2, broadcast_bias); - - std::unique_ptr fused_op = + std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); TF_RETURN_IF_ERROR(fused_op->set_backend_config(config)); @@ -1261,38 +1224,45 @@ TF_RETURN_IF_ERROR( ->set_output_to_operand_aliasing({{{}, {2, {}}}}); } TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); - + if (slice != nullptr) { + fused_op = slice->CloneWithNewOperands( + slice->shape(), + {slice->parent()->AddInstruction(std::move(fused_op))}); + } + if (bitcast != nullptr) { fused_op = bitcast->CloneWithNewOperands( bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); } - std::cout <<"at exiting fusematrixbiasadd:\n"; - std::cout << instr->ToString() << std::endl; - std::cout <<"----------------------------\n"; - std::cout << fused_op.get()->ToString() << std::endl; - std::cout <<"at exiting fusematrixbiasadd print done\n"; - return ReplaceWithNewInstruction(instr, std::move(fused_op)); - } - + return ReplaceWithNewInstruction(instr, std::move(fused_op)); } StatusOr FuseVectorBiasAdd(HloInstruction *instr, HloInstruction *broadcast, 
HloInstruction *gemm, HloInstruction *slice = nullptr, - HloInstruction *convert = nullptr) { + HloInstruction *convert = nullptr, + HloInstruction *bitcast = nullptr) { + VLOG(1) << "FuseVectorBiasAdd: 1"; + VLOG(1) << bitcast->ToString(); + VLOG(1) << instr->ToString(); + VLOG(1) << broadcast->ToString(); + VLOG(1) << gemm->ToString(); + VLOG(1) << "FuseVectorBiasAdd: 1 ends"; + if (!bitcast) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); - + } // Verify that the data type is supported by Epilogue Fusion. if (!SupportsEpilogueFusion(gemm->shape().element_type())) { return false; } HloInstruction *bias = broadcast->mutable_operand(0); - + VLOG(1) << bias->ToString(); + VLOG(1) << "FuseVectorBiasAdd: 2 ends"; TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); // # output column dims == # non-contracting rhs operand dims. @@ -1306,26 +1276,30 @@ TF_RETURN_IF_ERROR( (bias->shape().rank() != num_col_dims)) { return false; } + VLOG(1) << "FuseVectorBiasAdd: 3 ends"; // We require the bias vector to have been broadcast in the most major // dimensions; i.e. its most minor physical dimensions align with most minor // physical dimensions of the gemm output. absl::Span broadcast_dims = broadcast->dimensions(); - for (size_t i = 0; i < num_col_dims; ++i) { - int64_t dim = gemm->shape().layout().minor_to_major(i); - - // Find the corresponding dimension from the bias vector. - auto it = absl::c_find(broadcast_dims, dim); - - if (it == broadcast_dims.end()) { - return false; - } - - int64_t vector_dim = it - broadcast_dims.begin(); - if (bias->shape().layout().minor_to_major(i) != vector_dim) { - return false; - } + if (bitcast) { + broadcast_dims = gemm->shape().dimensions(); } - + // for (size_t i = 0; i < num_col_dims; ++i) { + // int64_t dim = gemm->shape().layout().minor_to_major(i); + + // // Find the corresponding dimension from the bias vector. + // auto it = absl::c_find(broadcast_dims, dim); + + // if (it == broadcast_dims.end()) { + // return false; + // } + + // int64_t vector_dim = it - broadcast_dims.begin(); + // if (bias->shape().layout().minor_to_major(i) != vector_dim) { + // return false; + // } + // } + VLOG(1) << "FuseVectorBiasAdd: 4 ends"; std::vector operands(gemm->operands().begin(), gemm->operands().end()); // When (non-trivial) matrix and vector bias co-exist for FP8 matmul, just @@ -1334,13 +1308,12 @@ TF_RETURN_IF_ERROR( config.beta() != 0.0) { return true; } - + VLOG(1) << "FuseVectorBiasAdd: 5 ends"; if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bias->shape().element_type() == F32) { if (convert == nullptr) { return false; } - HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0); auto compatible_bias_type = [](const PrimitiveType bias_type, const PrimitiveType output_type) { @@ -1369,7 +1342,50 @@ TF_RETURN_IF_ERROR( return false; } } + VLOG(1) << "FuseVectorBiasAdd: 6 ends"; + + absl::Span batch_dims = + config.dot_dimension_numbers().rhs_batch_dimensions(); + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; + }; + // Pad the non-batch dimensions of the operands to multiples of 16 as + // required by cuBLASLt. 
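+    // In the vector-bias case only the bias itself needs padding here; the
+    // matmul operands were already padded when the custom call was created.
+    // For example, in the Rank3ScaledABUnscaledDVectorBiasPaddedF8 test a
+    // f16[31] bias is padded to f16[32] to match the padded f16[64,32]
+    // cuBLASLt output, which is later sliced back to f16[60,31].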
+ auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { + PaddingConfig padding_config; + Shape padded_shape = pad_shape(x->shape()); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); + } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = + instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + x = instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return; + }; + if (bitcast) { + // bias = instr->AddInstruction(HloInstruction::CreateBitcast(slice->shape(),broadcast_bias)); + pad_operand(bias); + VLOG(1) << "FuseVectorBiasAdd: 7 ends"; + VLOG(1) << bias->ToString(); + } // Replace add(gemm, broadcast) with fused new_gemm. operands.push_back(bias); config.set_epilogue(GemmBackendConfig::BIAS); @@ -1381,8 +1397,22 @@ TF_RETURN_IF_ERROR( result = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(result))}); } + VLOG(1) << "FuseVectorBiasAdd: 7.5 ends"; + VLOG(1) << instr->ToString(); + VLOG(1) << result.get()->ToString(); + + if (bitcast != nullptr) { + result = bitcast->CloneWithNewOperands( + bitcast->shape(), + {bitcast->parent()->AddInstruction(std::move(result))}); + } + VLOG(1) << "FuseVectorBiasAdd: 7.75 ends"; + VLOG(1) << instr->ToString(); + VLOG(1) << result.get()->ToString(); TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result))); + VLOG(1) << "FuseVectorBiasAdd: 8 ends"; + VLOG(1) << instr->ToString(); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 78c8fbd4b42bbb..9aa93664ff7a91 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -1222,21 +1222,104 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[64,16]{1,0} bitcast([[P0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) +; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) +; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2) ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) -; CHECK-NEXT: [[B_BCAST:%[^ ]+]] = f16[4,16,32]{2,1,0} broadcast([[B_F16]]), dimensions={2} -; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f16[64,32]{1,0} bitcast([[B_BCAST]]) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: 
\"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]]) + )"); +} + +TEST_P(ParameterizedFp8GemmRewriteTest, + Rank3ScaledABUnscaledDVectorBiasPaddedF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; +#endif + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[4,15,15] parameter(0) + y = f8e4m3fn[15,31] parameter(1) + b = f32[31] parameter(2) + b_f16 = f16[31] convert(b) + b_bcast = f16[4,15,31] broadcast(b_f16), dimensions={2} + x_f16 = f16[4,15,15] convert(x) + y_f16 = f16[15,31] convert(y) + x_scale = f16[] parameter(3) + y_scale = f16[] parameter(4) + x_scale_bcast = f16[4,15,15] broadcast(x_scale), dimensions={} + y_scale_bcast = f16[15,31] broadcast(y_scale), dimensions={} + x_unscaled = f16[4,15,15] multiply(x_f16, x_scale_bcast) + x_unscaled_bitcast = f16[60,15] bitcast(x_unscaled) + y_unscaled = f16[15,31] multiply(y_f16, y_scale_bcast) + dot_a = f16[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_bitcast = f16[4,15,31]{2,1,0} bitcast(dot_a) + ROOT out = f16[4,15,31] add(dot_a_bitcast, b_bcast) + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::Slice(m::CustomCall({"__cublas$lt$matmul$f8"}) + .WithShape(F16, {64, 32})) + .WithShape(F16, {60, 31})) + .WithShape(F16, {4, 15, 31}))); + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,15,15], y: f8e4m3fn[15,31], b: f32[31], x_scale: f16[], y_scale: f16[]) -> f16[4,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[60,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) +; CHECK-NEXT: [[P0_PAD:%[^ ]+]] = f8e4m3fn[64,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_4x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) +; CHECK-NEXT: [[P1_PAD:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2_CV]], [[P3_CV]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2) +; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]]) +; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0) +; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1 +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]), ; CHECK: 
custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config="{ ; CHECK-DAG: \"alpha_real\":1 ; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"beta\":0 ; CHECK-DAG: \"dot_dimension_numbers\":{ ; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] ; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] @@ -1246,9 +1329,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-DAG: \"precision_config\":{ ; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] ; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK-DAG: \"epilogue\":\"BIAS\" ; CHECK: }" -; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]]) +; CHECK-NEXT: [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]]) )"); } @@ -1332,21 +1416,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, const char* hlo_text = R"( HloModule test ENTRY test { - x = f8e4m3fn[4,15,15] parameter(0) + x = f8e4m3fn[3,15,15] parameter(0) y = f8e4m3fn[15,31] parameter(1) - b = f32[4,15,31] parameter(2) - x_f32 = f32[4,15,15] convert(x) + b = f32[3,15,31] parameter(2) + x_f32 = f32[3,15,15] convert(x) y_f32 = f32[15,31] convert(y) x_scale = f32[] parameter(3) y_scale = f32[] parameter(4) - x_scale_bcast = f32[4,15,15] broadcast(x_scale), dimensions={} + x_scale_bcast = f32[3,15,15] broadcast(x_scale), dimensions={} y_scale_bcast = f32[15,31] broadcast(y_scale), dimensions={} - x_unscaled = f32[4,15,15] multiply(x_f32, x_scale_bcast) - x_unscaled_bitcast = f32[60,15] bitcast(x_unscaled) + x_unscaled = f32[3,15,15] multiply(x_f32, x_scale_bcast) + x_unscaled_bitcast = f32[45,15] bitcast(x_unscaled) y_unscaled = f32[15,31] multiply(y_f32, y_scale_bcast) - dot_a = f32[60,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a_bitcast = f32[4,15,31]{2,1,0} bitcast(dot_a) - ROOT out = f32[4,15,31] add(dot_a_bitcast, b) + dot_a = f32[45,31] dot(x_unscaled_bitcast, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_bitcast = f32[3,15,31]{2,1,0} bitcast(dot_a) + ROOT out = f32[3,15,31] add(dot_a_bitcast, b) } )"; @@ -1357,26 +1441,34 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Bitcast(m::CustomCall({"__cublas$lt$matmul$f8"}) - .WithShape(F32, {64, 32})) - .WithShape(F32, {4, 16, 32}))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::Slice(m::CustomCall({"__cublas$lt$matmul$f8"}) + .WithShape(F32, {48, 32})) + .WithShape(F32, {45, 31})) + .WithShape(F32, {3, 15, 31}))); RunAndFilecheckHloRewrite(hlo_text, GemmRewriter(se::CudaComputeCapability{ se::CudaComputeCapability::HOPPER, 0}), R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[4,15,15], y: f8e4m3fn[15,31], b: f32[4,15,31], x_scale: f32[], y_scale: f32[]) -> f32[4,15,31] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[4,15,15]{2,1,0} parameter(0) -; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[60,15]{1,0} bitcast([[P0]]) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[B:%[^ ]+]] = f32[4,16,32]{2,1,0} parameter(2) -; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[3,15,15], y: 
f8e4m3fn[15,31], b: f32[3,15,31], x_scale: f32[], y_scale: f32[]) -> f32[3,15,31] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[3,15,15]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f8e4m3fn[45,15]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f8e4m3fn[48,16]{1,0} pad([[P0_BITCAST]], [[C1]]), padding=0_3x0_1 +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[15,31]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[31,15]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[C2:%[^ ]+]] = f8e4m3fn[] constant(0) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f8e4m3fn[32,16]{1,0} pad([[P1_TRANSPOSE]], [[C2]]), padding=0_1x0_1 +; CHECK-NEXT: [[B:%[^ ]+]] = f32[3,15,31]{2,1,0} parameter(2) +; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[45,31]{1,0} bitcast([[B]]) +; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(0) +; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config="{ ; CHECK-DAG: \"alpha_real\":1 @@ -1393,8 +1485,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-DAG: } ; CHECK-DAG: \"epilogue\":\"DEFAULT\" ; CHECK: }" -; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]]) -; CHECK: %slice = f32[60,31]{1,0} slice(%cublas-gemm.2), slice={[0:60], [0:31]} +; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]]) )"); } From b0db89270cb6b200241c3267c526b19a39f29214 Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 10 Jun 2023 21:11:30 -0700 Subject: [PATCH 011/410] F8 all pass --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 52f8f808687900..1f8569e1467387 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -505,7 +505,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *optional_convert = nullptr; HloInstruction *optional_bitcast = nullptr; VLOG(1)<<"Yes, in HandleAdd!\n"; - VLOG(1)<<"shuw:" << instr->GetModule()->ToString(); + // VLOG(1)<<"shuw:" << instr->GetModule()->ToString(); // only on for rank3 debug // Attempt to elide broadcast and fuse addition of a vector bias into GEMM, // including when slicing is applied to the result. 
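+    // The pattern is add(bitcast(slice(cublasLt-matmul)), broadcast(bias)),
+    // with both the bitcast and the slice optional, so the vector bias can be
+    // fused even when the matmul output is padded (sliced) and/or reshaped
+    // back to a higher rank (bitcast).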
if (Match(instr, @@ -1246,12 +1246,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *convert = nullptr, HloInstruction *bitcast = nullptr) { VLOG(1) << "FuseVectorBiasAdd: 1"; - VLOG(1) << bitcast->ToString(); - VLOG(1) << instr->ToString(); - VLOG(1) << broadcast->ToString(); - VLOG(1) << gemm->ToString(); + // Only for rank3 debug + // if(bitcast) VLOG(1) << bitcast->ToString(); + // if(instr) VLOG(1) << instr->ToString(); + // if(broadcast) VLOG(1) << broadcast->ToString(); + // if(gemm) VLOG(1) << gemm->ToString(); VLOG(1) << "FuseVectorBiasAdd: 1 ends"; - if (!bitcast) { + if (bitcast == nullptr) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); } From c96473b84d1d403e23fa3a8abad834b393dcc718 Mon Sep 17 00:00:00 2001 From: shuw Date: Sat, 10 Jun 2023 21:55:20 -0700 Subject: [PATCH 012/410] mising last allocate --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 2 + .../service/gpu/tests/gemm_rewrite_test.cc | 4285 ++++++++++++++++- 2 files changed, 4217 insertions(+), 70 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 1f8569e1467387..6c23c047db480d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -1190,7 +1190,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } return; }; + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget){ pad_operand(broadcast_bias); + } VLOG(1) << "FuseMatrixBiasAdd:44444444444444444444\n"; // if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { // operands.at(2) = broadcast_bias; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 9aa93664ff7a91..4a0e938a8d9d31 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -64,6 +64,136 @@ class GemmRewriteTest : public GpuCodegenTest { bool tf32_state_; }; +TEST_F(GemmRewriteTest, CheckCustomCallTarget) { + const char* hlo_text = R"( +HloModule SimpleGemm + +ENTRY AddDotsFunc { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + if (debug_options.xla_gpu_enable_cublaslt()) { + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); + } else { + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$gemm")"); + } +} + +TEST_F(GemmRewriteTest, TestBatchedAutotuning) { + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() + << "There is no autotuning starting with the Nvidia Ampere generation"; + } + const char* hlo_text = R"( +HloModule ComplexDotMultipleNonContracting + +ENTRY %test { + %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0) + %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1) + ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3} +} + +)"; + + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: selected_algorithm + )"); +} + +TEST_F(GemmRewriteTest, SimpleRewriteDeterministic) { + const char* hlo_text = R"( +HloModule SimpleGemm + +ENTRY AddDotsFunc { + x = f32[128,128] parameter(0) + y = f32[128,128] parameter(1) 
+ ROOT dot_a = f32[128,128] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + ErrorSpec error_spec = [&] { + DebugOptions debug_options = GetDebugOptionsForTest(); + if (debug_options.xla_gpu_enable_cublaslt()) { + return ErrorSpec{1e-3, 1e-3}; + } else { + return ErrorSpec{1e-3, 1e-3}; + } + }(); + + auto get_module = [&]() { + HloModuleConfig config; + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_deterministic_ops(true); + config.set_debug_options(debug_options); + return ParseAndReturnVerifiedModule(hlo_text, config); + }; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr optimized_module, + backend().compiler()->RunHloPasses( + *get_module(), backend().default_stream_executor(), + backend().default_stream_executor()->GetAllocator())); + + StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(), + R"( +; CHECK: custom_call_target="__cublas${{(lt\$matmul|gemm)}}" + )"); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(filecheck_result.value()); + EXPECT_TRUE(RunAndCompare(*get_module(), error_spec)); +} + +TEST_F(GemmRewriteTest, BF16GemmCodeGen) { + const char* hlo_text = R"( +HloModule bf16codegendgemm + +ENTRY bf16gemm { + %parameter.1 = bf16[3]{0} parameter(0) + %parameter.2 = bf16[3]{0} parameter(1) + ROOT %dot.3 = bf16[] dot(bf16[3]{0} %parameter.1, bf16[3]{0} %parameter.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, operand_precision={highest,highest} +} + )"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: [[P1:%[^ ]+]] = bf16[3]{0} parameter(1) +; CHECK: [[INSTR_1:%[^ ]+]] = f32[3]{0} convert([[P1]]) +; CHECK: [[P0:%[^ ]+]] = bf16[3]{0} parameter(0) +; CHECK: [[INSTR_3:%[^ ]+]] = f32[3]{0} convert([[P0]]) +; CHECK: [[INSTR_4:%[^ ]+]] = f32[3]{0} multiply([[INSTR_1]], [[INSTR_3]]) +; CHECK: [[INSTR_5:%[^ ]+]] = f32[] constant(0) +; CHECK: [[INSTR_6:%[^ ]+]] = f32[] reduce([[INSTR_4]], [[INSTR_5]]), dimensions={0}, to_apply=[[INSTR_7:%[^ ]+]] +; CHECK: ROOT [[INSTR_8:%[^ ]+]] = bf16[] convert([[INSTR_6]]) + )"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(GemmRewriteTest, BF16Transpose) { + const char* hlo_text = R"( +HloModule broadcast + +ENTRY broadcast { + p = bf16[9] parameter(0) + ROOT out = bf16[1,9] broadcast(p), dimensions={1} +} +)"; + + MatchOptimizedHlo(hlo_text, R"( +; CHECK: bf16[1,9]{1,0} bitcast +; CHECK: bf16[1,9]{1,0} copy +)"); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + // A test fixture class for tests which should have similar results with legacy // cublas and cublasLt class ParameterizedGemmRewriteTest @@ -89,13 +219,3967 @@ class ParameterizedGemmRewriteTest return replacements_[kCustomCallTargetPlaceholder]; } - protected: - absl::flat_hash_map replacements_; + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kCustomCallTargetPlaceholder{ + "<>"}; +}; + +TEST_P(ParameterizedGemmRewriteTest, Simple) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), 
+; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, SimpleRewrite) { + const char* hlo_text = R"( +HloModule SimpleGemm + +ENTRY AddDotsFunc { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + ROOT dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, MultipleContractingDims) { + const char* hlo_text = R"( +HloModule MultipleContractingCheckGemm + +ENTRY AddDotsFunc { + x = f32[3,4,2] parameter(0) + y = f32[3,4,5] parameter(1) + ROOT dot_a = f32[2,5] dot(x, y), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-NOT: copy +; +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,4,2], y: f32[3,4,5]) -> f32[2,5] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,4,2]{2,1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4,5]{2,1,0} parameter(1) +; CHECK-DAG: [[BITCAST0:%[^ ]+]] = f32[2,12]{0,1} bitcast([[P0]]) +; CHECK-DAG: [[BITCAST1:%[^ ]+]] = f32[12,5]{1,0} bitcast([[P1]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5]{1,0} custom-call([[BITCAST0]], [[BITCAST1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, ArgTransposeFoldCheck) { + const char* hlo_text = R"( +HloModule ArgTransposeFoldGemm + +ENTRY AddDotsFunc { + x = f32[3,2] parameter(0) + y = f32[3,4] parameter(1) + x_transposed = f32[2,3] transpose(x), dimensions={1, 0} + ROOT dot_a = f32[2,4] dot(x_transposed, 
y), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2], y: f32[3,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BatchedArgRowColTransposeFoldCheck) { + const char* hlo_text = R"( +HloModule BatchedArgRowColTransposeFoldGemm + +ENTRY AddDotsFunc { + x = f32[5,3,2] parameter(0) + y = f32[5,3,4] parameter(1) + x_transposed = f32[5,2,3] transpose(x), dimensions={0, 2, 1} + ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,3,2], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,3,2]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BatchRowTransposeFoldCheck) { + const char* hlo_text = R"( +HloModule BatchRowTransposeFoldCheck + +ENTRY AddDotsFunc { + x = f32[2,5,3] parameter(0) + y = f32[5,3,4] parameter(1) + x_transposed = f32[5,2,3] transpose(x), dimensions={1, 0, 2} + ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,5,3]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] 
+; CHECK-DAG: \"lhs_batch_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BatchFromMinorDimTransposeIsNotFolded) { + const char* hlo_text = R"( +HloModule BatchFromMinorDimTransposeDoesntFold + +ENTRY AddDotsFunc { + x = f32[3,2,5] parameter(0) + y = f32[5,3,4] parameter(1) + x_transposed = f32[5,2,3] transpose(x), dimensions={2, 1, 0} + ROOT dot_a = f32[5,2,4] dot(x_transposed, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,2,5]{2,1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) +; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[5,2,3]{2,1,0} transpose([[P0]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[FUSION]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, LargeBatch) { + const char* hlo_text = R"( +HloModule BatchedArgRowColTransposeFoldGemm + +ENTRY AddDotsFunc { + x = f32[20000,4,3,2] parameter(0) + y = f32[20000,4,3,4] parameter(1) + ROOT dot_a = f32[20000,4,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} + +)"; + + // Batch sizes larger than 2^16-1 are not supported by cublasLt. Ensure that + // the custom_call_target is __cublas$gemm. 
+ EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[20000,4,3,2], y: f32[20000,4,3,4]) -> f32[20000,4,2,4] { +; CHECK: [[P0:%[^ ]+]] = f32[20000,4,3,2]{3,2,1,0} parameter(0) +; CHECK: [[BC0:%[^ ]+]] = f32[80000,3,2]{2,1,0} bitcast([[P0]]) +; CHECK: [[P1:%[^ ]+]] = f32[20000,4,3,4]{3,2,1,0} parameter(1) +; CHECK: [[BC1:%[^ ]+]] = f32[80000,3,4]{2,1,0} bitcast([[P1]]) +; CHECK: [[OUT:%[^ ]+]] = f32[80000,2,4]{2,1,0} custom-call([[BC0]], [[BC1]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK: }" +; CHECK: ROOT {{[^ ]+}} = f32[20000,4,2,4]{3,2,1,0} bitcast([[OUT]]) +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, InstrTransposeFoldCheck) { + const char* hlo_text = R"( +HloModule InstrTransposeFoldGemm + +ENTRY AddDotsFunc { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[4,2] transpose(dot_a), dimensions={1, 0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,3], y: f32[3,4]) -> f32[4,2] { +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} custom-call([[P1]], [[P0]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutTransposed) { + const char* hlo_text = R"( +HloModule BatchedInstrLayoutCheck + +ENTRY AddDotsFunc { + x = f32[5,2,3] parameter(0) + y = f32[5,3,4] parameter(1) + dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,5,4] transpose(dot_a), dimensions={1, 0, 2} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,0,1} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] +; CHECK-DAG: 
\"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,5,4]{2,1,0} bitcast([[GEMM]]) +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BatchedInstrLayoutBatchNotInMinorDim) { + const char* hlo_text = R"( +HloModule BatchedInstrLayoutBatchNotInMinorDim + +ENTRY AddDotsFunc { + x = f32[5,2,3] parameter(0) + y = f32[5,3,4] parameter(1) + dot_a = f32[5,2,4] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,4,5] transpose(dot_a), dimensions={1, 2, 0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[5,2,3]{2,1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[5,3,4]{2,1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[5,2,4]{2,1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"2\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: \"rhs_batch_dimensions\":[\"0\"] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4,5]{2,1,0} [[OP:[^ ]+]]([[GEMM]]) +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, AlphaSimpleRewrite) { + const char* hlo_text = R"( +HloModule AlphaSimpleRewrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, ComplexAlphaSimpleRewrite) { + const char* hlo_text = R"( +HloModule ComplexAlphaSimpleRewrite + +ENTRY AddDotsFunc { + x = c64[2,2] parameter(0) + y = c64[2,2] parameter(1) + k = c64[] constant((3.0, 3.0)) + k_broadcast = c64[2, 2] broadcast(k), 
dimensions={} + dot_a = c64[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT dot_a_multiplied = c64[2, 2] multiply(dot_a, k_broadcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: c64[2,2], y: c64[2,2]) -> c64[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = c64[2,2]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = c64[2,2]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = c64[2,2]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":3 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, AlphaMultipleUsersNoRewrite) { + const char* hlo_text = R"( +HloModule AlphaMultipleUsersNoRewrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + ROOT out = f32[2,2] add(dot_a_multiplied, dot_a) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: {{[^ ]+}} = f32[2,2]{1,0} custom-call({{[^,]+}}, {{[^)]+}}), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, AlphaVectorNoRewrite) { + const char* hlo_text = R"( +HloModule AlphaVectorNoRewrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + alpha = f32[2] constant({1, 2}) + alpha_broadcast = f32[2,2] broadcast(alpha), dimensions={1} + dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT dot_a_multiplied = f32[2, 2] multiply(dot, alpha_broadcast) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="<>", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: 
\"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_P(ParameterizedGemmRewriteTest, BF16Gemm) { + const char* hlo_text = R"( +HloModule bf16gemm + +ENTRY bf16gemm { + %parameter.1 = bf16[12,4]{1,0} parameter(0) + %parameter.2 = bf16[4,8]{1,0} parameter(1) + ROOT %dot.8 = bf16[12,8] dot(bf16[12,4] %parameter.1, bf16[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: bf16[16,8]{1,0} custom-call(bf16[16,8]{1,0} {{.*}}, bf16[8,8]{1,0} {{.*}}), custom_call_target="<>" + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: bf16[12,8]{1,0} custom-call(bf16[12,4]{1,0} [[P0:%[^ ]+]], bf16[4,8]{1,0} [[P1:%[^ ]+]]), custom_call_target="<>" + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, BF16GemmStrided) { + const char* hlo_text = R"( +HloModule bf16gemm + +ENTRY bf16gemm { + %parameter.1 = bf16[3,3,4] parameter(0) + %parameter.2 = bf16[3,3,2] parameter(1) + ROOT %dot.3 = bf16[3,4,2]{2,1,0} dot(bf16[3,3,4]{2,1,0} %parameter.1, bf16[3,3,2]{2,1,0} %parameter.2), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,highest} +} + + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::AMPERE)) { + MatchOptimizedHlo(hlo_text, + R"( + ; CHECK: bf16[3,8,8]{2,1,0} custom-call(bf16[3,8,8]{2,1,0} {{.*}}, bf16[3,8,8]{2,1,0} {{.*}}), custom_call_target="<>" + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( + ; CHECK: ROOT [[OUT:%[^ ]+]] = bf16[3,4,2]{2,1,0} custom-call(bf16[3,3,4]{2,1,0} [[A:%[^ ]+]], bf16[3,3,2]{2,1,0} [[B:%[^ ]+]]), custom_call_target="<>" + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, Int8Gemm) { + const char* hlo_text = R"( +HloModule int8gemm + +ENTRY int8gemm { + %parameter.1 = s8[12,4]{1,0} parameter(0) + %parameter.2 = s8[4,8]{1,0} parameter(1) + ROOT %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoAlphaRewrite) { + const char* hlo_text = R"( +HloModule int8gemm + +ENTRY int8gemm { + %parameter.1 = s8[12,4]{1,0} parameter(0) + %parameter.2 = s8[4,8]{1,0} parameter(1) + k = s32[] constant(2) + k_broadcast = s32[12,8] broadcast(k), dimensions={} + %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT dot_multiplied = s32[12,8] 
multiply(%dot.8, k_broadcast) +} + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, Int8GemmNoBetaRewrite) { + const char* hlo_text = R"( +HloModule int8gemm + +ENTRY int8gemm { + %parameter.1 = s8[12,4]{1,0} parameter(0) + %parameter.2 = s8[4,8]{1,0} parameter(1) + bias = s32[12,8] parameter(2) + %dot.8 = s32[12,8] dot(s8[12,4] %parameter.1, s8[4,8] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = s32[12,8] add(%dot.8, bias) +} + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} custom-call(s8[12,4]{1,0} [[A:%[^ ]+]], s8[4,8]{0,1} [[B:%[^ ]+]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[12,8]{1,0} dot(s32[12,4]{1,0} [[A:%[^ ]+]], s32[4,8]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, Int8GemmNotMultipleOfFour) { + const char* hlo_text = R"( +HloModule int8gemm + +ENTRY int8gemm { + %parameter.1 = s8[13,4]{1,0} parameter(0) + %parameter.2 = s8[4,9]{1,0} parameter(1) + ROOT %dot.9 = s32[13,9] dot(s8[13,4] %parameter.1, s8[4,9] %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + )"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[16,12]{1,0} custom-call(s8[16,4]{1,0} [[A:%[^ ]+]], s8[4,12]{0,1} [[B:%[^ ]+]]), custom_call_target="__cublas$gemm" + )", + /*print_operand_shape=*/true); + } else { + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: s32[13,9]{1,0} dot(s32[13,4]{1,0} [[A:%[^ ]+]], s32[4,9]{1,0} [[B:%[^ ]+]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + )", + /*print_operand_shape=*/true); + } +} + +TEST_P(ParameterizedGemmRewriteTest, GemmTypeCombinationCheck) { + std::vector> + type_combinations = { + {"s8", "s32", true}, {"s8", "s8", true}, {"s32", "s32", true}, + {"bf16", "bf16", true}, {"f16", "f16", true}, {"f32", "f32", true}, + {"f64", "f64", true}, {"c64", "c64", true}, {"c128", "c128", true}, + }; + + if (GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + // For compute capabilities before volta, we always do upcasting, so it + // would be impossible for this test to fail. That is why we only add these + // cases when the compute capabilit is at least Volta. 
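+    // Each entry is (input type, output type, whether RunAndCompare is
+    // expected to succeed); std::get<2> drives the EXPECT at the bottom of
+    // the loop below.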
+ std::vector> + more_type_combinations = { + {"s8", "bf16", false}, {"s8", "f16", false}, + {"s8", "f32", false}, {"s8", "f64", false}, + {"s8", "c64", false}, {"s8", "c128", false}, + + {"s32", "f32", false}, {"s32", "f64", false}, + {"s32", "c64", false}, {"s32", "c128", false}, + + {"f16", "bf16", false}, {"f16", "f32", false}, + {"f16", "f64", false}, {"f16", "c64", false}, + {"f16", "c128", false}, + + {"bf16", "f16", false}, {"bf16", "f64", false}, + {"bf16", "c64", false}, {"bf16", "c128", false}, + + {"f32", "f64", false}, {"f32", "c64", false}, + {"f32", "c128", false}, + + {"f64", "c64", false}, {"f64", "c128", false}, + }; + type_combinations.insert(type_combinations.end(), + more_type_combinations.begin(), + more_type_combinations.end()); + } + + for (const auto& type_combination : type_combinations) { + absl::flat_hash_map replacements; + replacements["<>"] = std::get<0>(type_combination); + replacements["<>"] = std::get<1>(type_combination); + const char* hlo_template = R"( + HloModule type_combo + + ENTRY type_combo { + %parameter.1 = <>[4,4]{1,0} parameter(0) + %parameter.2 = <>[4,4]{1,0} parameter(1) + ROOT %dot = <>[4,4] dot(%parameter.1, %parameter.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements); + if (std::get<2>(type_combination)) { + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + } else { + EXPECT_FALSE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + } + } +} + +TEST_P(ParameterizedGemmRewriteTest, UpcastingBf16ToF64) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + Arg_0.1 = bf16[4,3]{1,0} parameter(0) + Arg_1.2 = bf16[3,6]{1,0} parameter(1) + ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + // This is a type combination which is not supported by cublasLt, expect + // GemmRewriter to choose legacy cublas. + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall({"__cublas$gemm"}))); +} + +TEST_P(ParameterizedGemmRewriteTest, UpcastingC64ToC128) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + Arg_0.1 = c64[4,3]{1,0} parameter(0) + Arg_1.2 = c64[3,6]{1,0} parameter(1) + ROOT dot.3 = c128[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + // This is a type combination which is not supported by cublasLt, expect + // GemmRewriter to choose legacy cublas. 
+ EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall({"__cublas$gemm"}))); +} + +TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF32) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + Arg_0.1 = f16[4,3]{1,0} parameter(0) + Arg_1.2 = f16[3,6]{1,0} parameter(1) + ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall({CustomCallTarget()}))); +} + +TEST_P(ParameterizedGemmRewriteTest, UpcastingF16ToF64) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + Arg_0.1 = f16[4,3]{1,0} parameter(0) + Arg_1.2 = f16[3,6]{1,0} parameter(1) + ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + // This is a type combination which is not supported by cublasLt, expect + // GemmRewriter to choose legacy cublas. + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall({"__cublas$gemm"}))); +} + +TEST_P(ParameterizedGemmRewriteTest, UpcastingF32ToF64) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + Arg_0.1 = f32[4,3]{1,0} parameter(0) + Arg_1.2 = f32[3,6]{1,0} parameter(1) + ROOT dot.3 = f64[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + // This is a type combination which is not supported by cublasLt, expect + // GemmRewriter to choose legacy cublas. + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall({"__cublas$gemm"}))); +} + +INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt, + ParameterizedGemmRewriteTest, ::testing::Bool()); + +// A test fixture class for tests which are specific to legacy cublas +class LegacyCublasGemmRewriteTest : public GemmRewriteTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cublaslt(false); + return debug_options; + } +}; + +// Test that the alpha and beta fields of the GemmBackendConfig are updated. +// A bias must be present for the beta value to be set. +// In order to have a bias add fused, the bias term must be overwritable. +// We assume that we may not overwrite parameters of a computation. Hence, we +// use the third parameter to create a new value which can be overwritten and +// will be used as the bias. This negate(param_2) has no semantic use, it simply +// exists so that bias may be overwritten. 
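+// When the fusion succeeds, the bias becomes operand 2 of the __cublas$gemm
+// custom call and the output aliases that operand
+// (output_to_operand_aliasing={{}: (2, {})}), i.e. the result is written over
+// the bias buffer in place, which is why the bias must be overwritable.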
+TEST_F(LegacyCublasGemmRewriteTest, AlphaBetaRewrite) { + const char* hlo_text = R"( +HloModule NonZeroAlphaBeta + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + param_2 = f32[2,2] parameter(2) + bias = f32[2,2] negate(param_2) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + ROOT out = f32[2,2] add(dot_a_multiplied, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, BiasMultipleUsersNoOverwrite) { + const char* hlo_text = R"( +HloModule BiasMultipleUsersNoOverwrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] parameter(2) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + biased_out = f32[2,2] add(dot_a_multiplied, bias) + ROOT out = f32[2,2] add(biased_out, bias) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, BiasParameterNoOverwrite) { + const char* hlo_text = R"( +HloModule BiasParameterNoOverwrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] parameter(2) + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,2] add(dot_a, bias) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 
1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, BiasTupleParameterOverwrite) { + const char* hlo_text = R"( +HloModule BiasTupleParameterOverwrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + param_2 = (f32[2,2], f32[3,3]) parameter(2) + bias = f32[2,2] get-tuple-element(param_2), index=0 + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,2] add(dot_a, bias) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], param_2: (f32[2,2], f32[3,3])) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = (f32[2,2]{1,0}, f32[3,3]{1,0}) parameter(2) +; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} get-tuple-element([[P2]]), index=0 +; CHECK-DAG: [[BIAS_COPY:%[^ ]+]] = f32[2,2]{1,0} copy([[BIAS]]) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[BIAS_COPY]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, AliasedBiasOverwrite) { + const char* hlo_text = R"( +HloModule AliasedBiasOverwrite, input_output_alias={ {}: (2, {}, must-alias) } + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] parameter(2) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + ROOT out = f32[2,2] add(dot_a_multiplied, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} 
parameter(2) +; CHECK: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) { + const char* hlo_text = R"( +HloModule LargerBiasMultipleUsersNoRewrite + +ENTRY AddDotsFunc { + x = f32[1024,1024] parameter(0) + y = f32[1024,1024] parameter(1) + bias = f32[1024,1024] parameter(2) + dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + biased_out = f32[1024,1024] add(dot_a, bias) + ROOT out = f32[1024,1024] add(biased_out, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: f32[1024,1024]) -> f32[1024,1024] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +// In order to have a bias add fused, the bias term must be overwritable. +// We assume that we may not overwrite parameters of a computation. Hence, we +// use the third parameter to create a new value which can be overwritten and +// will be used as the bias. This negate(param_2) has no semantic use, it simply +// exists so that bias may be overwritten. 
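+// Note that the bf16 comparison below uses a looser ErrorSpec (2e-3) than the
+// f32 tests (1e-5), presumably to allow for bf16's reduced precision.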
+TEST_F(LegacyCublasGemmRewriteTest, BF16GemmWithBias) { + const char* hlo_text = R"( +HloModule BF16GemmWithBias + +ENTRY BF16GemmWithBias { + x = bf16[8,8]{1,0} parameter(0) + y = bf16[8,8]{1,0} parameter(1) + dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + param_2 = bf16[8,8]{1,0} parameter(2) + bias = bf16[8,8]{1,0} negate(param_2) + ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias) +} + )"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2e-3, 2e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], param_2: bf16[8,8]) -> bf16[8,8] { +; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) +; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) +; CHECK: ROOT [[GEMM:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[X]], [[Y]], {{[^,)]+}}), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +// In order to have a bias add fused, the bias term must be overwritable. +// We assume that we may not overwrite parameters of a computation. Hence, we +// use the third parameter to create a new value which can be overwritten and +// will be used as the bias. This negate(param_2) has no semantic use, it simply +// exists so that bias may be overwritten. 
+TEST_F(LegacyCublasGemmRewriteTest, MatrixBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + param_2 = f32[2,4] parameter(2) + bias = f32[2,4] negate(param_2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,4] add(dot_a, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], param_2: f32[2,4]) -> f32[2,4] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK: ROOT [[GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], {{[^,)]+}}), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + w = f32[2,3] parameter(0) + x = f32[3,4] parameter(1) + first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0} + y = f32[2,3] parameter(2) + z = f32[3,4] parameter(3) + second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,4] add(second_dot, first_dot) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3) +; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[SECOND_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P2]], [[P3]], [[FIRST_GEMM]]), +; CHECK: custom_call_target="__cublas$gemm", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: 
\"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(LegacyCublasGemmRewriteTest, MergeBitcastAndAdd) { + const char* hlo_text = R"( +HloModule test +ENTRY test { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[4] parameter(2) + dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[4] add(f32[4] bitcast(dot), bias) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Bitcast( + m::CustomCall({"__cublas$gemm"}, m::Parameter(0), m::Parameter(1), + m::Bitcast(m::Parameter(2)).WithShape(F32, {2, 2}))) + .WithShape(F32, {4}))); +} + +// In order to have a bias add fused, the bias term must be overwritable. +// We assume that we may not overwrite parameters of a computation. Hence, we +// use the third parameter to create a new value which can be overwritten and +// will be used as the bias. This negate(param_2) has no semantic use, it simply +// exists so that bias may be overwritten. +TEST_F(LegacyCublasGemmRewriteTest, FoldConstantBias) { + const char* hlo_text = R"( +HloModule test +ENTRY test { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0} + + dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + param_2 = f32[2,2] parameter(2) + bias1 = f32[2,2] negate(param_2) + sum1 = add(dot1, bias1) + + dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + sum2 = add(dot2, f32[2,2] reshape(bias)) + + dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bias3 = f32[2,2] transpose(bias), dimensions={1,0} + sum3 = add(dot3, bias3) + + dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + sum4 = add(dot4, f32[2,2] bitcast(bias)) + + ROOT root = tuple(sum1, sum2, sum3, sum4) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + SCOPED_TRACE(module->ToString()); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::CustomCall(m::Parameter(0), m::Parameter(1), + m::Negate(m::Parameter(2))), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant())))); +} + +// A test fixture class for tests which are specific to cublasLt +class CublasLtGemmRewriteTest : public GemmRewriteTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_cublaslt(true); + return debug_options; + } +}; + +TEST_F(CublasLtGemmRewriteTest, AlphaBetaRewrite) { + const char* hlo_text = R"( +HloModule NonZeroAlphaBeta + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] parameter(2) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + 
dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + ROOT out = f32[2,2] add(dot_a_multiplied, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[X:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[Y:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, BiasMultipleUsersNoOverwrite) { + const char* hlo_text = R"( +HloModule BiasMultipleUsersNoOverwrite + +ENTRY AddDotsFunc { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] parameter(2) + k = f32[] constant(3.0) + k_broadcast = f32[2, 2] broadcast(k), dimensions={} + dot_a = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 2] multiply(dot_a, k_broadcast) + biased_out = f32[2,2] add(dot_a_multiplied, bias) + ROOT out = f32[2,2] add(biased_out, bias) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,2]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,2]{1,0} parameter(1) +; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[2,2]{1,0} parameter(2) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[2,2]{1,0} custom-call([[P0]], [[P1]], [[BIAS]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK-NOT: output_to_operand_aliasing +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, LargerBiasMultipleUsersNoRewrite) { + const char* hlo_text = R"( +HloModule LargerBiasMultipleUsersNoRewrite + +ENTRY AddDotsFunc { + x = f32[1024,1024] parameter(0) + y = f32[1024,1024] parameter(1) + bias = f32[1024,1024] parameter(2) + dot_a = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + biased_out = f32[1024,1024] add(dot_a, bias) + ROOT out = f32[1024,1024] add(biased_out, bias) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[1024,1024], y: f32[1024,1024], bias: 
f32[1024,1024]) -> f32[1024,1024] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) +; CHECK-DAG: [[BIAS:%[^ ]+]] = f32[1024,1024]{1,0} parameter(2) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[1024,1024]{1,0} custom-call([[P0]], [[P1]], [[BIAS]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[1024,1024]{1,0} add([[GEMM]], [[BIAS]]) +)"); +} + +TEST_F(CublasLtGemmRewriteTest, BF16GemmWithBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY BF16GemmWithBias { + x = bf16[8,8]{1,0} parameter(0) + y = bf16[8,8]{1,0} parameter(1) + dot.5 = bf16[8,8]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bias = bf16[8,8]{1,0} parameter(2) + ROOT add.6 = bf16[8,8]{1,0} add(dot.5, bias) +} + )"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %BF16GemmWithBias (x: bf16[8,8], y: bf16[8,8], bias: bf16[8,8]) -> bf16[8,8] { +; CHECK-DAG: [[X:%[^ ]+]] = bf16[8,8]{1,0} parameter(0) +; CHECK-DAG: [[Y:%[^ ]+]] = bf16[8,8]{1,0} parameter(1) +; CHECK-DAG: [[BIAS:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[GEMM:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[X]], [[Y]], [[BIAS]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, MatrixBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2,4] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,4] add(dot_a, z) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; 
CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, MatrixBiasWhereBiasIsNotAParameter) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + w = f32[2,3] parameter(0) + x = f32[3,4] parameter(1) + first_dot = f32[2,4] dot(w, x), lhs_contracting_dims={1}, rhs_contracting_dims={0} + y = f32[2,3] parameter(2) + z = f32[3,4] parameter(3) + second_dot = f32[2,4] dot(y, z), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,4] add(second_dot, first_dot) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (w: f32[2,3], x: f32[3,4], y: f32[2,3], z: f32[3,4]) -> f32[2,4] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = f32[2,3]{1,0} parameter(2) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[3,4]{1,0} parameter(3) +; CHECK-NEXT: [[FIRST_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[SECOND_GEMM:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P2]], [[P3]], [[FIRST_GEMM]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: output_to_operand_aliasing={{{{}: \(2, {}\)}}}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={1} + ROOT out = f32[2,4] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: 
\"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +)"); +} + +// Epilogue Fusion disabled when GEMM has multiple users. +TEST_F(CublasLtGemmRewriteTest, VectorBiasMultipleUsers) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[4,4] parameter(0) + y = f32[4,4] parameter(1) + z = f32[4] parameter(2) + c = f32[] constant(5) + dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[4,4] broadcast(z), dimensions={1} + add_a = f32[4,4] add(dot_a, z_bcast) + c_bcast = f32[4,4] broadcast(c), dimensions={} + dot_b = f32[4,4] dot(dot_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[4,4] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[4,4], [[DUMMY1:[^ ]+]]: f32[4]) -> f32[4,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4]{0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} broadcast([[P1]]), dimensions={1} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} add([[P0]], [[P2]]) +} + +; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4]) -> f32[4,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) +; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[4,4]{1,0} fusion([[MATMUL0]], [[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) +; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[4,4]{1,0} broadcast([[C0]]), dimensions={} +; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[4,4]{1,0} custom-call([[MATMUL0]], [[C0_BCAST]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} custom-call([[FUSION]], [[MATMUL1]]), +; CHECK: 
custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, BatchedVectorBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3,4] parameter(0) + y = f32[4,5,6] parameter(1) + z = f32[3,5,6] parameter(2) + dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} + z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3} + ROOT out = f32[2,3,5,6] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[3,5,6]) -> f32[6,30] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={1,2,3} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} bitcast([[P0_BCAST]]) +} + +; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) + )"); +} + +TEST_F(CublasLtGemmRewriteTest, BatchedSharedVectorBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3,4] parameter(0) + y = f32[4,5,6] parameter(1) + z = f32[6] parameter(2) + dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} + z_bcast = f32[2,3,5,6] broadcast(z), dimensions={3} + ROOT out = f32[2,3,5,6] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[6]) -> f32[6,30] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[6]{0} parameter(0) +; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={3} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} 
bitcast([[P0_BCAST]]) +} + +; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[6]) -> f32[2,3,5,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[6]{0} parameter(2) +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: output_to_operand_aliasing={{[{][{]}}}: (2, {})}, +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasIncorrectAxisFusedAsMatrix) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={0} + add = f32[2,4] add(dot_a, z_bcast) + ROOT out = f32[4,2] transpose(add), dimensions={1,0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]]) +)"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasSliced) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[4,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[3] parameter(2) + dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + slice_a = f32[2,3] slice(dot_a), slice={[0:2], [0:3]} + z_bcast = f32[2,3] broadcast(z), dimensions={1} + ROOT out = f32[2,3] add(slice_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[4,3], y: f32[3,4], z: f32[3]) -> f32[2,3] { +; CHECK-NEXT: [[P0:%[^ 
]+]] = f32[4,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3]{1,0} slice([[MATMUL]]), slice={[0:2], [0:3]} + )"); +} + +// Epilogue Fusion disabled when slice has multiple users. +TEST_F(CublasLtGemmRewriteTest, VectorBiasSlicedMultipleUsers) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2] parameter(2) + c = f32[] constant(5) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]} + z_bcast = f32[2,2] broadcast(z), dimensions={1} + add_a = f32[2,2] add(slice_a, z_bcast) + c_bcast = f32[2,2] broadcast(c), dimensions={} + dot_b = f32[2,2] dot(slice_a, c_bcast), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f32[2,2] dot(add_a, dot_b), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[2], [[DUMMY1:[^ ]+]]: f32[2,4]) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2]{0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[2,4]{1,0} parameter(1) +; CHECK-DAG: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[P1]]), slice={[0:2], [0:2]} +; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[P0]]), dimensions={1} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} add([[SLICE]], [[P0_BCAST]]) +} + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[2,2] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) +; CHECK-NEXT: [[MATMUL0:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,2]{1,0} fusion([[P2]], [[MATMUL0]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] +; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL0]]), slice={[0:2], [0:2]} +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(5) +; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2]{1,0} broadcast([[C0]]), dimensions={} +; CHECK-NEXT: [[MATMUL1:%[^ ]+]] = f32[2,2]{1,0} 
custom-call([[SLICE]], [[C0_BCAST]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} custom-call([[FUSION]], [[MATMUL1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposed) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] parameter(3) + ROOT out = f32[2,4] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2_BCAST:%[^ ]+]] = f32[2,4]{1,0} parameter(3) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2_BCAST]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + z2 = f32[2,4] parameter(3) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={1} + add0 = f32[2,4] add(dot_a, z_bcast) + ROOT add1 = f32[2,4] add(add0, z2) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4], z2: f32[2,4]) -> f32[2,4] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-DAG: [[VECTOR_BIAS:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-DAG: [[MATRIX_BIAS:%[^ ]+]] = f32[2,4]{1,0} parameter(3) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = 
f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[MATRIX_BIAS]], [[VECTOR_BIAS]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +)"); +} + +TEST_F(CublasLtGemmRewriteTest, BF16VectorBias) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[16,24] parameter(0) + y = bf16[24,32] parameter(1) + z = bf16[32] parameter(2) + dot_a = bf16[16,32] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[16,32] broadcast(z), dimensions={1} + ROOT out = bf16[16,32] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{3e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[16,24], y: bf16[24,32], z: bf16[32]) -> bf16[16,32] { +; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[16,24]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[24,32]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[32]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[16,32]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, BF16VectorBiasPadded) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Padding of GEMM bf16 operands only implemented on " + "architectures with bf16 Tensor Cores."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[2,3] parameter(0) + y = bf16[3,4] parameter(1) + z = bf16[4] parameter(2) + dot_a = bf16[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[2,4] broadcast(z), dimensions={1} + ROOT out = bf16[2,4] add(dot_a, z_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[2,3], y: bf16[3,4], z: bf16[4]) -> bf16[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = bf16[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = bf16[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P0]], [[C0]]), padding=0_6x0_5 +; CHECK-NEXT: [[P1:%[^ ]+]] = bf16[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = bf16[8,8]{1,0} pad([[P1]], [[C0]]), padding=0_5x0_4 +; CHECK-NEXT: [[P2:%[^ ]+]] = bf16[4]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; 
CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[2,4]{1,0} slice([[MATMUL]]), slice={[0:2], [0:4]} + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + ROOT out = f32[2,4] maximum(dot_a, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, BatchedReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3,4] parameter(0) + y = f32[4,5,6] parameter(1) + dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c = f32[] constant(0) + c_bcast = f32[2,3,5,6] broadcast(c), dimensions={} + ROOT out = f32[2,3,5,6] maximum(dot_a, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6]) -> f32[2,3,5,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,5,6]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ReluActivationSliced) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = 
f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f32[] constant(0) + c_bcast = f32[2,2] broadcast(c), dimensions={} + slice_a = f32[2,2] slice(dot_a), slice={[0:2], [0:2]} + ROOT out = f32[2,2] maximum(slice_a, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,2]{1,0} slice([[MATMUL]]), slice={[0:2], [0:2]} + )"); +} + +TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2,4] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + add = f32[2,4] add(dot_a, z) + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + ROOT out = f32[2,4] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2,4]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, SquareMatrixBiasReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[4,4] parameter(0) + y = f32[4,4] parameter(1) + z = f32[4,4] parameter(2) + dot_a = f32[4,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + add = f32[4,4] add(dot_a, z) + c = f32[] constant(0) + c_bcast = f32[4,4] broadcast(c), dimensions={} + ROOT out = f32[4,4] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[4,4], y: f32[4,4], z: f32[4,4]) -> f32[4,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[4,4]{1,0} parameter(0) +; 
CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4,4]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={1} + add = f32[2,4] add(dot_a, z_bcast) + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + ROOT out = f32[2,4] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, BatchedVectorBiasReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3,4] parameter(0) + y = f32[4,5,6] parameter(1) + z = f32[3,5,6] parameter(2) + dot_a = f32[2,3,5,6] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={0} + z_bcast = f32[2,3,5,6] broadcast(z), dimensions={1,2,3} + add = f32[2,3,5,6] add(dot_a, z_bcast) + c = f32[] constant(0) + c_bcast = f32[2,3,5,6] broadcast(c), dimensions={} + ROOT out = f32[2,3,5,6] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK: [[FUSED_COMPUTATION:%[^ ]+]] ([[DUMMY0:[^ ]+]]: f32[3,5,6]) -> f32[6,30] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BCAST:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} broadcast([[P0]]), dimensions={1,2,3} +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[6,30]{1,0} bitcast([[P0_BCAST]]) +} + +; CHECK-LABEL: ENTRY %test (x: f32[2,3,4], y: f32[4,5,6], z: f32[3,5,6]) -> f32[2,3,5,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3,4]{2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[6,4]{1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = 
f32[4,5,6]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[4,30]{1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[3,5,6]{2,1,0} parameter(2) +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[6,30]{1,0} fusion([[P2]]), kind=kLoop, calls=[[FUSED_COMPUTATION]] +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[6,30]{1,0} custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[FUSION]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,3,5,6]{3,2,1,0} bitcast([[MATMUL]]) + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasTransposedReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[2] parameter(2) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={0} + add = f32[2,4] add(dot_a, z_bcast) + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + maximum = f32[2,4] maximum(add, c_bcast) + ROOT out = f32[4,2] transpose(maximum), dimensions={1,0} +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[2]) -> f32[4,2] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[2]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f32[2,4]{0,1} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[4,2]{1,0} bitcast([[MATMUL]]) + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasThenMatrixBiasReluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z_vec = f32[4] parameter(2) + z_matrix = f32[2,4] parameter(3) + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z_vec), dimensions={1} + add0 = f32[2,4] add(dot_a, z_bcast) + add1 = f32[2,4] add(add0, z_matrix) + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + ROOT out = f32[2,4] maximum(add1, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z_vec: f32[4], z_matrix: f32[2,4]) -> f32[2,4] { +; CHECK-DAG: [[P0:%[^ ]+]] = f32[2,3]{1,0} 
parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4]{1,0} parameter(3) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P3]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul.0 = f32[2,4] multiply(dot, dot) + mul.1 = f32[2,4] multiply(dot, mul.0) + const.0 = f32[] constant(0.044715) + bcast.0 = f32[2,4] broadcast(const.0), dimensions={} + mul.2 = f32[2,4] multiply(mul.1, bcast.0) + add.0 = f32[2,4] add(dot, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[2,4] broadcast(const.1), dimensions={} + mul.3 = f32[2,4] multiply(add.0, bcast.1) + tanh = f32[2,4] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[2,4] broadcast(const.2), dimensions={} + add.2 = f32[2,4] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[2,4] broadcast(const.3), dimensions={} + mul.4 = f32[2,4] multiply(add.2, bcast.3) + ROOT out = f32[2,4] multiply(dot, mul.4) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"GELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWrongConstant) { + // Modify one constant slightly, so it should no longer pattern match. 
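+ // Specifically, const.0 is 0.05 here instead of the 0.044715 used in the
+ // ApproxGeluActivation test above, so the GELU pattern should not be matched.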
+ const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul.0 = f32[2,4] multiply(dot, dot) + mul.1 = f32[2,4] multiply(dot, mul.0) + const.0 = f32[] constant(0.05) + bcast.0 = f32[2,4] broadcast(const.0), dimensions={} + mul.2 = f32[2,4] multiply(mul.1, bcast.0) + add.0 = f32[2,4] add(dot, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[2,4] broadcast(const.1), dimensions={} + mul.3 = f32[2,4] multiply(add.0, bcast.1) + tanh = f32[2,4] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[2,4] broadcast(const.2), dimensions={} + add.2 = f32[2,4] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[2,4] broadcast(const.3), dimensions={} + mul.4 = f32[2,4] multiply(add.2, bcast.3) + ROOT out = f32[2,4] multiply(dot, mul.4) +} + +)"; + + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-NOT: GELU + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={1} + add = f32[2,4] add(dot, z_bcast) + mul.0 = f32[2,4] multiply(add, add) + mul.1 = f32[2,4] multiply(add, mul.0) + const.0 = f32[] constant(0.044715) + bcast.0 = f32[2,4] broadcast(const.0), dimensions={} + mul.2 = f32[2,4] multiply(mul.1, bcast.0) + add.0 = f32[2,4] add(add, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[2,4] broadcast(const.1), dimensions={} + mul.3 = f32[2,4] multiply(add.0, bcast.1) + tanh = f32[2,4] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[2,4] broadcast(const.2), dimensions={} + add.2 = f32[2,4] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[2,4] broadcast(const.3), dimensions={} + mul.4 = f32[2,4] multiply(add.2, bcast.3) + ROOT out = f32[2,4] multiply(add, mul.4) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_GELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ApproxGeluActivationWithAux) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + mul.0 = f32[2,4] multiply(dot, dot) + mul.1 = f32[2,4] multiply(dot, mul.0) + const.0 = f32[] constant(0.044715) + bcast.0 = f32[2,4] broadcast(const.0), 
dimensions={} + mul.2 = f32[2,4] multiply(mul.1, bcast.0) + add.0 = f32[2,4] add(dot, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[2,4] broadcast(const.1), dimensions={} + mul.3 = f32[2,4] multiply(add.0, bcast.1) + tanh = f32[2,4] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[2,4] broadcast(const.2), dimensions={} + add.2 = f32[2,4] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[2,4] broadcast(const.3), dimensions={} + mul.4 = f32[2,4] multiply(add.2, bcast.3) + mul.5 = f32[2,4] multiply(dot, mul.4) + ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, dot) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4]) -> (f32[2,4], f32[2,4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}) custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"GELU_AUX\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasThenApproxGeluActivationWithAux) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + dot = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f32[2,4] broadcast(z), dimensions={1} + add = f32[2,4] add(dot, z_bcast) + mul.0 = f32[2,4] multiply(add, add) + mul.1 = f32[2,4] multiply(add, mul.0) + const.0 = f32[] constant(0.044715) + bcast.0 = f32[2,4] broadcast(const.0), dimensions={} + mul.2 = f32[2,4] multiply(mul.1, bcast.0) + add.0 = f32[2,4] add(add, mul.2) + const.1 = f32[] constant(0.797884583) + bcast.1 = f32[2,4] broadcast(const.1), dimensions={} + mul.3 = f32[2,4] multiply(add.0, bcast.1) + tanh = f32[2,4] tanh(mul.3) + const.2 = f32[] constant(1) + bcast.2 = f32[2,4] broadcast(const.2), dimensions={} + add.2 = f32[2,4] add(tanh, bcast.2) + const.3 = f32[] constant(0.5) + bcast.3 = f32[2,4] broadcast(const.3), dimensions={} + mul.4 = f32[2,4] multiply(add.2, bcast.3) + mul.5 = f32[2,4] multiply(add, mul.4) + ROOT out = (f32[2,4], f32[2,4]) tuple(mul.5, add) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> (f32[2,4], f32[2,4]) { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4]{1,0}, f32[2,4]{1,0}) custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; 
CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_GELU_AUX\" +; CHECK: }" + )"); +} + +// For F16, the sizes of all dimensions of the operands are required to be +// multiples of 8 to allow matrix bias fusion. +TEST_F(CublasLtGemmRewriteTest, MatrixBiasF16) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[8,16] parameter(0) + y = f16[16,8] parameter(1) + z = f16[8,8] parameter(2) + dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = f16[8,8] add(dot_a, z) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" + )"); +} + +// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and +// newer architectures) so that the sizes of all dimensions are multiples of 8. 
+TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[8,16] parameter(0) + y = f16[16,8] parameter(1) + z = f16[8] parameter(2) + dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f16[8,8] broadcast(z), dimensions={1} + ROOT add = f16[8,8] add(dot_a, z_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + GTEST_SKIP() << "Padding of GEMM operands only implemented on " + "architectures with Tensor Cores."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[6,12] parameter(0) + y = f16[12,6] parameter(1) + z = f16[6] parameter(2) + dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f16[6,6] broadcast(z), dimensions={1} + ROOT add = f16[6,6] add(dot_a, z_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +; CHECK-NEXT: [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + )"); +} + +// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and +// newer architectures) so that the sizes of all dimensions are multiples of 8. 
+TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[8,16] parameter(0) + y = f16[16,8] parameter(1) + dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f16[] constant(0) + c_bcast = f16[8,8] broadcast(c), dimensions={} + ROOT out = f16[8,8] maximum(dot_a, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8]) -> f16[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ReluActivationF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + GTEST_SKIP() << "Padding of GEMM operands only implemented on " + "architectures with Tensor Cores."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[6,12] parameter(0) + y = f16[12,6] parameter(1) + dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f16[] constant(0) + c_bcast = f16[6,6] broadcast(c), dimensions={} + ROOT out = f16[6,6] maximum(dot_a, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6]) -> f16[6,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + )"); +} + +TEST_F(CublasLtGemmRewriteTest, MatrixBiasReluActivationF16) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[8,16] parameter(0) + y = f16[16,8] parameter(1) + z = f16[8,8] parameter(2) + dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + add = f16[8,8] add(dot_a, z) + c = 
f16[] constant(0) + c_bcast = f16[8,8] broadcast(c), dimensions={} + ROOT out = f16[8,8] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8,8]) -> f16[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8,8]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +// For F16, the operands are padded on GPUs with Tensor Cores (i.e. Volta and +// newer architectures) so that the sizes of all dimensions are multiples of 8. +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[8,16] parameter(0) + y = f16[16,8] parameter(1) + z = f16[8] parameter(2) + dot_a = f16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f16[8,8] broadcast(z), dimensions={1} + add = f16[8,8] add(dot_a, z_bcast) + c = f16[] constant(0) + c_bcast = f16[8,8] broadcast(c), dimensions={} + ROOT out = f16[8,8] maximum(add, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[8,16], y: f16[16,8], z: f16[8]) -> f16[8,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[8,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[16,8]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast(se::CudaComputeCapability::VOLTA)) { + GTEST_SKIP() << "Padding of GEMM operands only implemented on " + "architectures with Tensor Cores."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f16[6,12] parameter(0) + y = f16[12,6] parameter(1) + z = f16[6] parameter(2) + dot_a = f16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f16[6,6] broadcast(z), dimensions={1} + add = f16[6,6] add(dot_a, z_bcast) + c = f16[] constant(0) + c_bcast = f16[6,6] broadcast(c), dimensions={} + ROOT out = f16[6,6] maximum(add, 
c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f16[6,12], y: f16[12,6], z: f16[6]) -> f16[6,6] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[6,12]{1,0} parameter(0) +; CHECK-NEXT: [[C0:%[^ ]+]] = f16[] constant(0) +; CHECK-NEXT: [[P0_PADDED:%[^ ]+]] = f16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[12,6]{1,0} parameter(1) +; CHECK-NEXT: [[P1_PADDED:%[^ ]+]] = f16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[6]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = f16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: "beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +// For bfloat16, the sizes of all dimensions of the operands are required to be +// multiples of 8 to allow matrix bias fusion. +TEST_F(CublasLtGemmRewriteTest, MatrixBiasBF16) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[8,16] parameter(0) + y = bf16[16,8] parameter(1) + z = bf16[8,8] parameter(2) + dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT out = bf16[8,8] add(dot_a, z) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8,8]) -> bf16[8,8] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8,8]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":1 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, MatrixBiasBitcastBF16) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[8,16] parameter(0) + y = bf16[16,8] parameter(1) + bias = bf16[2,4,8] parameter(2) + dot = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast = bf16[2,4,8] bitcast(dot) + ROOT out = bf16[2,4,8] add(bitcast, bias) +} + +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch( + m::Bitcast(m::CustomCall( + 
{"__cublas$lt$matmul"}, + m::Parameter(0).WithShape(BF16, {8, 16}), + m::Parameter(1).WithShape(BF16, {16, 8}), + m::Bitcast(m::Parameter(2)).WithShape(BF16, {8, 8}))) + .WithShape(BF16, {2, 4, 8}))); +} + +// For bfloat16, the operands are padded if necessary on Ampere and newer +// architectures so that the sizes of all dimensions are multiples of 8. +TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[8,16] parameter(0) + y = bf16[16,8] parameter(1) + z = bf16[8] parameter(2) + dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[8,8] broadcast(z), dimensions={1} + ROOT add = bf16[8,8] add(dot_a, z_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasBF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " + "Ampere and newer architectures."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[6,12] parameter(0) + y = bf16[12,6] parameter(1) + z = bf16[6] parameter(2) + dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[6,6] broadcast(z), dimensions={1} + ROOT add = bf16[6,6] add(dot_a, z_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) +; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) +; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) +; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: 
\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS\" +; CHECK: }" +; CHECK-NEXT: [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + )"); +} + +// For bfloat16, the operands are padded if necessary on Ampere and newer +// architectures so that the sizes of all dimensions are multiples of 8. +TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[8,16] parameter(0) + y = bf16[16,8] parameter(1) + dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = bf16[] constant(0) + c_bcast = bf16[8,8] broadcast(c), dimensions={} + ROOT out = bf16[8,8] maximum(dot_a, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8]) -> bf16[8,8] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, ReluActivationBF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " + "Ampere and newer architectures."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[6,12] parameter(0) + y = bf16[12,6] parameter(1) + dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = bf16[] constant(0) + c_bcast = bf16[6,6] broadcast(c), dimensions={} + ROOT out = bf16[6,6] maximum(dot_a, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6]) -> bf16[6,6] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) +; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) +; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) +; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = 
bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + )"); +} + +// For bfloat16, the operands are padded if necessary on Ampere and newer +// architectures so that the sizes of all dimensions are multiples of 8. +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Unpadded) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[8,16] parameter(0) + y = bf16[16,8] parameter(1) + z = bf16[8] parameter(2) + dot_a = bf16[8,8] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[8,8] broadcast(z), dimensions={1} + add = bf16[8,8] add(dot_a, z_bcast) + c = bf16[] constant(0) + c_bcast = bf16[8,8] broadcast(c), dimensions={} + ROOT out = bf16[8,8] maximum(add, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{8e-3, 2e-3})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: bf16[8,16], y: bf16[16,8], z: bf16[8]) -> bf16[8,8] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[8,16]{1,0} parameter(0) +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[16,8]{1,0} parameter(1) +; CHECK-DAG: [[P2:%[^ ]+]] = bf16[8]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationBF16Padded) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE)) { + GTEST_SKIP() << "Padding of GEMM operands in bfloat16 only implemented on " + "Ampere and newer architectures."; + } + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = bf16[6,12] parameter(0) + y = bf16[12,6] parameter(1) + z = bf16[6] parameter(2) + dot_a = bf16[6,6] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = bf16[6,6] broadcast(z), dimensions={1} + add = bf16[6,6] add(dot_a, z_bcast) + c = bf16[] constant(0) + c_bcast = bf16[6,6] broadcast(c), dimensions={} + ROOT out = bf16[6,6] maximum(add, c_bcast) +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); + MatchOptimizedHlo(hlo_text, + R"( - private: - static constexpr const char* kCustomCallTargetPlaceholder{ - "<>"}; -}; +; CHECK-LABEL: ENTRY %test (x: bf16[6,12], y: bf16[12,6], z: bf16[6]) -> bf16[6,6] { +; CHECK-DAG: [[P0:%[^ ]+]] = bf16[6,12]{1,0} parameter(0) +; CHECK-DAG: [[C0:%[^ ]+]] = bf16[] constant(0) +; CHECK-DAG: [[P0_PADDED:%[^ ]+]] = bf16[8,16]{1,0} pad([[P0]], [[C0]]), padding=0_2x0_4 +; CHECK-DAG: [[P1:%[^ ]+]] = bf16[12,6]{1,0} parameter(1) +; CHECK-DAG: [[P1_PADDED:%[^ ]+]] = bf16[16,8]{1,0} pad([[P1]], [[C0]]), padding=0_4x0_2 +; CHECK-DAG: [[P2:%[^ ]+]] = bf16[6]{0} parameter(2) +; CHECK-NEXT: [[MATMUL:%[^ ]+]] = bf16[8,8]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: 
\"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = bf16[6,6]{1,0} slice([[MATMUL]]), slice={[0:6], [0:6]} + )"); +} + +TEST_F(CublasLtGemmRewriteTest, VectorBiasReluActivationF64) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f64[2,3] parameter(0) + y = f64[3,4] parameter(1) + z = f64[4] parameter(2) + dot_a = f64[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + z_bcast = f64[2,4] broadcast(z), dimensions={1} + add = f64[2,4] add(dot_a, z_bcast) + c = f64[] constant(0) + c_bcast = f64[2,4] broadcast(c), dimensions={} + ROOT out = f64[2,4] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-10, 1e-10})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f64[2,3], y: f64[3,4], z: f64[4]) -> f64[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f64[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f64[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f64[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f64[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, AlphaSimpleRewriteBiasAddActivation) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + x = f32[2,3] parameter(0) + y = f32[3,4] parameter(1) + z = f32[4] parameter(2) + k = f32[] constant(3.0) + k_bcast = f32[2,4] broadcast(k), dimensions={} + dot_a = f32[2,4] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_multiplied = f32[2, 4] multiply(dot_a, k_bcast) + z_bcast = f32[2,4] broadcast(z), dimensions={1} + add = f32[2,4] add(dot_a_multiplied, z_bcast) + c = f32[] constant(0) + c_bcast = f32[2,4] broadcast(c), dimensions={} + ROOT out = f32[2,4] maximum(add, c_bcast) +} + +)"; + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + MatchOptimizedHlo(hlo_text, + R"( + +; CHECK-LABEL: ENTRY %test (x: f32[2,3], y: f32[3,4], z: f32[4]) -> f32[2,4] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,3]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[3,4]{1,0} parameter(1) +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[4]{0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[2,4]{1,0} custom-call([[P0]], [[P1]], [[P2]]), +; CHECK: custom_call_target="__cublas$lt$matmul", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":3 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"0\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: 
\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + +TEST_F(CublasLtGemmRewriteTest, FoldConstantBias) { + const char* hlo_text = R"( +HloModule test +ENTRY test { + x = f32[2,2] parameter(0) + y = f32[2,2] parameter(1) + bias = f32[2,2] broadcast(f32[2] constant({0, 0})), dimensions={0} + + dot1 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bias1 = f32[2,2] parameter(2) + sum1 = add(dot1, bias1) + + dot2 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + sum2 = add(dot2, f32[2,2] reshape(bias)) + + dot3 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bias3 = f32[2,2] transpose(bias), dimensions={1,0} + sum3 = add(dot3, bias3) + + dot4 = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} + sum4 = add(dot4, f32[2,2] bitcast(bias)) + + ROOT root = tuple(sum1, sum2, sum3, sum4) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass(GetCudaComputeCapability()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + SCOPED_TRACE(module->ToString()); + EXPECT_TRUE(changed); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple( + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Parameter()), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant()), + m::CustomCall(m::Parameter(0), m::Parameter(1), m::Constant())))); +} + +TEST_F(CublasLtGemmRewriteTest, MultipleMaximumUsers) { + const char* hlo_text = R"( +HloModule multiple_maximum_users + +relu { + Arg_0 = f32[3,896,54]{2,1,0} parameter(0) + constant = f32[] constant(0) + broadcast = f32[3,896,54]{2,1,0} broadcast(constant), dimensions={} + ROOT maximum = f32[3,896,54]{2,1,0} maximum(Arg_0, broadcast) +} + +ENTRY main { + constant = f32[] constant(1) + broadcast_1 = f32[3,896,1024]{2,1,0} broadcast(constant), dimensions={} + Arg_2 = f32[1024,54]{1,0} parameter(2) + dot = f32[3,896,54]{2,1,0} dot(broadcast_1, Arg_2), lhs_contracting_dims={2}, rhs_contracting_dims={0} + Arg_1 = f32[54]{0} parameter(1) + broadcast_2 = f32[3,896,54]{2,1,0} broadcast(Arg_1), dimensions={2} + add = f32[3,896,54]{2,1,0} add(dot, broadcast_2) + call = f32[3,896,54]{2,1,0} call(add), to_apply=relu + Arg_0 = f32[1]{0} parameter(0) + reshape_1 = f32[1,1,1]{2,1,0} reshape(Arg_0) + broadcast_3 = f32[1,1,1]{2,1,0} broadcast(reshape_1), dimensions={0,1,2} + reshape_2 = f32[] reshape(broadcast_3) + broadcast_4 = f32[3,896,54]{2,1,0} broadcast(reshape_2), dimensions={} + multiply = f32[3,896,54]{2,1,0} multiply(call, broadcast_4) + ROOT tuple = (f32[3,896,54]{2,1,0}, f32[3,896,54]{2,1,0}) tuple(multiply, call) +} +)"; + + // TODO(cjfj): Why do we need to relax the error constraint here?! + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-4})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: custom_call_target="__cublas$lt$matmul", + )"); +} class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { public: @@ -114,6 +4198,7 @@ class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { return; } EXPECT_TRUE(RunAndCompare(hlo_text, error_spec)); + // Most FP8 tests directly create a GemmRewriter and check the output. // Here, also run the entire HLO pass pipeline to ensure no other passes // interfere with GemmRewriter's pattern matching. 
@@ -1174,6 +5259,70 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDVectorBiasThenReluActivationF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; +#endif // CUDA_VERSION < 12000 + const char* hlo_text = R"( + HloModule test + + ENTRY test { + x = f8e4m3fn[16,32] parameter(0) + y = f8e4m3fn[32,16] parameter(1) + b = f16[16] parameter(2) + b_bcast = f16[16,16] broadcast(b), dimensions={1} + x_f32 = f16[16,32] convert(x) + y_f32 = f16[32,16] convert(y) + x_scale = f16[] parameter(3) + y_scale = f16[] parameter(4) + x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={} + y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={} + x_unscaled = f16[16,32] multiply(x_f32, x_scale_bcast) + y_unscaled = f16[32,16] multiply(y_f32, y_scale_bcast) + c = f16[] constant(0) + c_bcast = f16[16,16] broadcast(c), dimensions={} + dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a = f16[16,16] add(dot_a0, b_bcast) + ROOT out = f16[16,16] maximum(dot_a, c_bcast) + } +)"; + + CheckFp8IfOnHopper(hlo_text, ErrorSpec{2e-3, 0.}); + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16], x_scale: f16[], y_scale: f16[]) -> f16[16,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) +; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) +; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) +; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) +; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" +; CHECK: }" + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { #if CUDA_VERSION < 12000 GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; @@ -1490,70 +5639,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"); } -TEST_P(ParameterizedFp8GemmRewriteTest, - ScaledABUnscaledDVectorBiasThenReluActivationF8) { -#if CUDA_VERSION < 12000 - GTEST_SKIP() << "F8 gemm rewrite is only supported in CUDA 12 and above."; -#endif // CUDA_VERSION < 12000 - const char* hlo_text = R"( - HloModule test - - ENTRY test { - x = f8e4m3fn[16,32] parameter(0) - y = f8e4m3fn[32,16] parameter(1) - b = f16[16] parameter(2) - b_bcast = f16[16,16] broadcast(b), dimensions={1} - x_f32 = f16[16,32] convert(x) - y_f32 = f16[32,16] convert(y) - x_scale = f16[] 
parameter(3) - y_scale = f16[] parameter(4) - x_scale_bcast = f16[16,32] broadcast(x_scale), dimensions={} - y_scale_bcast = f16[32,16] broadcast(y_scale), dimensions={} - x_unscaled = f16[16,32] multiply(x_f32, x_scale_bcast) - y_unscaled = f16[32,16] multiply(y_f32, y_scale_bcast) - c = f16[] constant(0) - c_bcast = f16[16,16] broadcast(c), dimensions={} - dot_a0 = f16[16,16] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot_a = f16[16,16] add(dot_a0, b_bcast) - ROOT out = f16[16,16] maximum(dot_a, c_bcast) - } -)"; - - CheckFp8IfOnHopper(hlo_text, ErrorSpec{2e-3, 0.}); - RunAndFilecheckHloRewrite(hlo_text, - GemmRewriter(se::CudaComputeCapability{ - se::CudaComputeCapability::HOPPER, 0}), - R"( -; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[16,32], y: f8e4m3fn[32,16], b: f16[16], x_scale: f16[], y_scale: f16[]) -> f16[16,16] { -; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(0) -; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[32,16]{1,0} parameter(1) -; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[16,32]{1,0} transpose([[P1]]), dimensions={1,0} -; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) -; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) -; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) -; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), -; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS_RELU\" -; CHECK: }" - )"); -} - TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasThenVectorBiasF8) { #if CUDA_VERSION < 12000 @@ -2036,6 +6121,66 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt, ParameterizedFp8GemmRewriteTest, ::testing::Bool()); +TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) { + const char* hlo = R"( + +HloModule module + +ENTRY main.10 { + Arg_0.1 = f16[384,128]{1,0} parameter(0) + Arg_1.2 = f16[128,256]{1,0} parameter(1) + dot.4 = f16[384,256]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + Arg_2.3 = f16[256]{0} parameter(2) + reshape.5 = f16[1,256]{1,0} reshape(Arg_2.3) + broadcast.6 = f16[1,256]{1,0} broadcast(reshape.5), dimensions={0,1} + reshape.7 = f16[256]{0} reshape(broadcast.6) + broadcast.8 = f16[384,256]{1,0} broadcast(reshape.7), dimensions={1} + ROOT add.9 = f16[384,256]{1,0} add(dot.4, broadcast.8) +})"; + + MatchOptimizedHlo(hlo, R"( +// CHECK: \"beta\":0 + )"); +} + +class GemmRewriteAllocationTest : public GpuCodegenTest { + public: + void CheckNumberOfAllocations(const std::string& hlo, + int expected_number_of_allocations) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + backend().compiler()->RunBackend( + std::move(optimized_module), 
backend().default_stream_executor(),
+            backend().default_stream_executor()->GetAllocator()));
+    GpuExecutable* gpu_executable =
+        static_cast<GpuExecutable*>(executable.get());
+    absl::Span<const BufferAllocation> allocations =
+        gpu_executable->GetAllocations();
+    CHECK_EQ(allocations.size(), expected_number_of_allocations);
+  }
+};
+
+// TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
+//   const char* hlo_text = R"(
+// HloModule SharedBufferAssignment
+
+// ENTRY AddDotsFunc {
+//   x = f32[2,2] parameter(0)
+//   y = f32[2,2] parameter(1)
+//   bias = f32[2,2] add(x, y)
+//   dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+//   ROOT out = f32[2,2] add(dot, bias)
+// }
+
+// )";
+
+//   // Bias should be fused into the multiplication.
+//   CheckNumberOfAllocations(hlo_text, 3);
+//   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+// }
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla

From 3e1ad85c79fe469364e9e366538318774afaebeb Mon Sep 17 00:00:00 2001
From: shuw
Date: Mon, 12 Jun 2023 08:50:31 -0700
Subject: [PATCH 013/410] All pass

---
 .../service/gpu/tests/gemm_rewrite_test.cc    | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 4a0e938a8d9d31..4db9f60667d57b 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -6162,24 +6162,24 @@ class GemmRewriteAllocationTest : public GpuCodegenTest {
   }
 };
 
-// TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
-//   const char* hlo_text = R"(
-// HloModule SharedBufferAssignment
-
-// ENTRY AddDotsFunc {
-//   x = f32[2,2] parameter(0)
-//   y = f32[2,2] parameter(1)
-//   bias = f32[2,2] add(x, y)
-//   dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-//   ROOT out = f32[2,2] add(dot, bias)
-// }
-
-// )";
-
-//   // Bias should be fused into the multiplication.
+TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) {
+  const char* hlo_text = R"(
+HloModule SharedBufferAssignment
+
+ENTRY AddDotsFunc {
+  x = f32[2,2] parameter(0)
+  y = f32[2,2] parameter(1)
+  bias = f32[2,2] add(x, y)
+  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT out = f32[2,2] add(dot, bias)
+}
+
+)";
+
+  // Bias should be fused into the multiplication.
+ CheckNumberOfAllocations(hlo_text, 3); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} } // namespace } // namespace gpu From 9fb316ca3da873789eeb4131d92c24a1f25c562d Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 12 Jun 2023 09:25:47 -0700 Subject: [PATCH 014/410] Format and remove debug_comments --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 260 +++++++----------- 1 file changed, 96 insertions(+), 164 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 6c23c047db480d..973f473030d7dd 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -504,30 +504,28 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; HloInstruction *optional_bitcast = nullptr; - VLOG(1)<<"Yes, in HandleAdd!\n"; - // VLOG(1)<<"shuw:" << instr->GetModule()->ToString(); // only on for rank3 debug - // Attempt to elide broadcast and fuse addition of a vector bias into GEMM, - // including when slicing is applied to the result. + // Attempt to elide broadcast and fuse addition of a vector bias into + // GEMM, including when slicing is applied to the result. if (Match(instr, m::AddAnyOrder( - OptionalBitcast(&optional_bitcast,OptionalSlice( - &optional_slice, - CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) - .WithOneUser()).WithOneUser(), + OptionalBitcast( + &optional_bitcast, + OptionalSlice( + &optional_slice, + CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser()) + .WithOneUser()) + .WithOneUser(), m::Broadcast(&bias, OptionalConvert(&optional_convert, m::Op()))))) { - VLOG(1) << "Yes, in HandleAdd VECTOR!\n"; - TF_ASSIGN_OR_RETURN(bool was_fused, - FuseVectorBiasAdd(instr, bias, existing_gemm, - optional_slice, optional_convert, - optional_bitcast)); + TF_ASSIGN_OR_RETURN( + bool was_fused, + FuseVectorBiasAdd(instr, bias, existing_gemm, optional_slice, + optional_convert, optional_bitcast)); if (was_fused) { return OkStatus(); } } - // return OkStatus(); - VLOG(1)<<"HandleAdd:Yes, 11111111111111!\n"; // Attempt to elide broadcast and fuse addition of a vector bias into // *batched* GEMM as a matrix bias addition using FuseMatrixBiasAdd. // add(bitcast(gemm(a, b)), broadcast(bias)) -> @@ -550,7 +548,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below. instr = new_add; } - VLOG(1)<<"Yes, 22222222222222!\n"; // Do not fuse broadcast unless we can fuse its input, as it will cause // broadcast materialization. auto is_not_broadcast = [](const HloInstruction *instr) { @@ -585,22 +582,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below transforming new_add. instr = new_add; } - VLOG(1)<<"Yes, 333333333333! 
if gemm null\n" << (existing_gemm == nullptr); - VLOG(1)<<"Yes, 333333333333!\n" << instr->ToShortString(); if (Match(instr, m::AddAnyOrder( GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - VLOG(1) << "Yes, heads to FuseMatrixBiasAdd!\n"; return FuseMatrixBiasAdd(instr, bias, existing_gemm); } - VLOG(1) << "Herer cccccccccccccc\n"; - // if (bias) { - // VLOG(1) << "bias not emp"; - // VLOG(1) << bias->ToString(); - // } HloInstruction *optional_bitcast_matrix = nullptr; - HloInstruction* optional_slice_matrix = nullptr; + HloInstruction *optional_slice_matrix = nullptr; if (Match(instr, m::AddAnyOrder( OptionalBitcast( @@ -610,12 +599,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())) .WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { - VLOG(1) << "Herer uuuuuuuuuuuuuuuuuuuuuuuuuuuu\n"; - if (bias) { - VLOG(1) << "bias not emp"; - VLOG(1) << bias->ToString(); - VLOG(1) << bias->users()[0]->ToString(); - } return FuseMatrixBiasAdd(instr, bias, existing_gemm, optional_bitcast_matrix, optional_slice_matrix); } @@ -772,7 +755,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *bias = instr->users()[0]->mutable_operand( !instr->users()[0]->operand_index(instr)); if (bias->opcode() != HloOpcode::kBroadcast) { - VLOG(1) <<"WWWWWWWWWWWWWWWWWWWWWWWW"; c = bias; gemm_backend_config.set_beta(1.0); add = instr->users()[0]; @@ -1090,15 +1072,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const HloInstruction *gemm, HloInstruction *bitcast = nullptr, HloInstruction *slice = nullptr) { - // return OkStatus(); TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())); - VLOG(1) << "FuseMatrixBiasAdd:1111111111111111111\n"; // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. if (gemm->shape().element_type() == S32) { return OkStatus(); } - VLOG(1) << "FuseMatrixBiasAdd:22222222222222\n"; // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. bool can_overwrite_bias = [bias]() { @@ -1128,17 +1107,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return in_out_alias_config.ParameterHasAlias(bias->parameter_number(), /*param_index=*/{}); }(); - bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || IsCublasLtMatmul(*gemm) || can_overwrite_bias; + bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) || + IsCublasLtMatmul(*gemm) || can_overwrite_bias; auto config = gemm->backend_config().value(); - VLOG(1) << "FuseMatrixBiasAdd:3333333333333333\n"; // It is possible to fuse into a cublasLt matmul that already has a vector // bias, but no other epilogue will commute with the matrix bias add. 
bool supported_epilogue = ((config.epilogue() == GemmBackendConfig::DEFAULT) || (config.epilogue() == GemmBackendConfig::BIAS)); - VLOG(1) << config.beta() << "; " << want_to_fuse_bias << "; " - << gemm->user_count() << "; " << supported_epilogue << "; \n"; if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { return OkStatus(); @@ -1148,57 +1125,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); - HloInstruction* broadcast_bias = MaybeConstantFoldBias(bias); + HloInstruction *broadcast_bias = MaybeConstantFoldBias(bias); if (slice) { - broadcast_bias = instr->AddInstruction(HloInstruction::CreateBitcast(slice->shape(),broadcast_bias)); + broadcast_bias = instr->AddInstruction( + HloInstruction::CreateBitcast(slice->shape(), broadcast_bias)); } - VLOG(1) << broadcast_bias->ToString(); - if(bitcast) VLOG(1) << bitcast->ToString(); - absl::Span batch_dims = + absl::Span batch_dims = config.dot_dimension_numbers().rhs_batch_dimensions(); - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); } - return padded_shape; - }; + } + return padded_shape; + }; - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. - auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape()); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget){ - pad_operand(broadcast_bias); + // Pad the non-batch dimensions of the operands to multiples of 16 as + // required by cuBLASLt. 
+ auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { + PaddingConfig padding_config; + Shape padded_shape = pad_shape(x->shape()); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); } - VLOG(1) << "FuseMatrixBiasAdd:44444444444444444444\n"; - // if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - // operands.at(2) = broadcast_bias; - // } else { - operands.insert(operands.begin() + 2, broadcast_bias); - // } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = + instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + x = instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return; + }; + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { + pad_operand(broadcast_bias); + } + + operands.insert(operands.begin() + 2, broadcast_bias); std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); @@ -1228,10 +1200,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); if (slice != nullptr) { fused_op = slice->CloneWithNewOperands( - slice->shape(), - {slice->parent()->AddInstruction(std::move(fused_op))}); + slice->shape(), + {slice->parent()->AddInstruction(std::move(fused_op))}); } - + if (bitcast != nullptr) { fused_op = bitcast->CloneWithNewOperands( bitcast->shape(), @@ -1247,16 +1219,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *slice = nullptr, HloInstruction *convert = nullptr, HloInstruction *bitcast = nullptr) { - VLOG(1) << "FuseVectorBiasAdd: 1"; - // Only for rank3 debug - // if(bitcast) VLOG(1) << bitcast->ToString(); - // if(instr) VLOG(1) << instr->ToString(); - // if(broadcast) VLOG(1) << broadcast->ToString(); - // if(gemm) VLOG(1) << gemm->ToString(); - VLOG(1) << "FuseVectorBiasAdd: 1 ends"; - if (bitcast == nullptr) { - TF_RET_CHECK(ShapeUtil::Compatible( - broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); + if (bitcast == nullptr) { + TF_RET_CHECK(ShapeUtil::Compatible( + broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); } // Verify that the data type is supported by Epilogue Fusion. if (!SupportsEpilogueFusion(gemm->shape().element_type())) { @@ -1264,8 +1229,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } HloInstruction *bias = broadcast->mutable_operand(0); - VLOG(1) << bias->ToString(); - VLOG(1) << "FuseVectorBiasAdd: 2 ends"; TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); // # output column dims == # non-contracting rhs operand dims. @@ -1279,7 +1242,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { (bias->shape().rank() != num_col_dims)) { return false; } - VLOG(1) << "FuseVectorBiasAdd: 3 ends"; // We require the bias vector to have been broadcast in the most major // dimensions; i.e. its most minor physical dimensions align with most minor // physical dimensions of the gemm output. @@ -1287,22 +1249,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (bitcast) { broadcast_dims = gemm->shape().dimensions(); } - // for (size_t i = 0; i < num_col_dims; ++i) { - // int64_t dim = gemm->shape().layout().minor_to_major(i); - - // // Find the corresponding dimension from the bias vector. 
- // auto it = absl::c_find(broadcast_dims, dim); - - // if (it == broadcast_dims.end()) { - // return false; - // } - - // int64_t vector_dim = it - broadcast_dims.begin(); - // if (bias->shape().layout().minor_to_major(i) != vector_dim) { - // return false; - // } - // } - VLOG(1) << "FuseVectorBiasAdd: 4 ends"; std::vector operands(gemm->operands().begin(), gemm->operands().end()); // When (non-trivial) matrix and vector bias co-exist for FP8 matmul, just @@ -1311,7 +1257,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.beta() != 0.0) { return true; } - VLOG(1) << "FuseVectorBiasAdd: 5 ends"; if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bias->shape().element_type() == F32) { if (convert == nullptr) { @@ -1345,49 +1290,45 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } } - VLOG(1) << "FuseVectorBiasAdd: 6 ends"; - absl::Span batch_dims = + absl::Span batch_dims = config.dot_dimension_numbers().rhs_batch_dimensions(); - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } + // Get the padded shape. + auto pad_shape = [&batch_dims](const Shape old_shape) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); } - return padded_shape; - }; + } + return padded_shape; + }; - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. - auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape()); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; + // Pad the non-batch dimensions of the operands to multiples of 16 as + // required by cuBLASLt. 
+ auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { + PaddingConfig padding_config; + Shape padded_shape = pad_shape(x->shape()); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); + } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = + instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + x = instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return; + }; if (bitcast) { - // bias = instr->AddInstruction(HloInstruction::CreateBitcast(slice->shape(),broadcast_bias)); - pad_operand(bias); - VLOG(1) << "FuseVectorBiasAdd: 7 ends"; - VLOG(1) << bias->ToString(); + pad_operand(bias); } // Replace add(gemm, broadcast) with fused new_gemm. operands.push_back(bias); @@ -1400,22 +1341,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { result = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(result))}); } - VLOG(1) << "FuseVectorBiasAdd: 7.5 ends"; - VLOG(1) << instr->ToString(); - VLOG(1) << result.get()->ToString(); if (bitcast != nullptr) { result = bitcast->CloneWithNewOperands( bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(result))}); } - VLOG(1) << "FuseVectorBiasAdd: 7.75 ends"; - VLOG(1) << instr->ToString(); - VLOG(1) << result.get()->ToString(); - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(result))); - VLOG(1) << "FuseVectorBiasAdd: 8 ends"; - VLOG(1) << instr->ToString(); return true; } From 5204bfbdc507777fffb0b4c6fefd78f273b3086d Mon Sep 17 00:00:00 2001 From: Song Ziming Date: Tue, 13 Jun 2023 19:55:11 +0800 Subject: [PATCH 015/410] Simplified tflite::transpose_utils::Flatten function --- tensorflow/lite/kernels/internal/transpose_utils.cc | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tensorflow/lite/kernels/internal/transpose_utils.cc b/tensorflow/lite/kernels/internal/transpose_utils.cc index 76808020853050..730a722dd31bb3 100644 --- a/tensorflow/lite/kernels/internal/transpose_utils.cc +++ b/tensorflow/lite/kernels/internal/transpose_utils.cc @@ -143,18 +143,7 @@ size_t Flatten(const RuntimeShape& input_shape, for (int i = skip_dims_cnt; i < params.perm_count; ++i) { non_flatten_input_shape->SetDim(i - skip_dims_cnt, input_shape.Dims(i)); non_flatten_output_shape->SetDim(i - skip_dims_cnt, output_shape.Dims(i)); - non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i]; - } - for (int i = 0; i < new_dims_cnt; ++i) { - int min_val_idx = -1; - for (int j = 0; j < new_dims_cnt; ++j) { - if (non_flatten_params->perm[j] >= i && - (min_val_idx == -1 || non_flatten_params->perm[min_val_idx] > - non_flatten_params->perm[j])) { - min_val_idx = j; - } - } - non_flatten_params->perm[min_val_idx] = i; + non_flatten_params->perm[i - skip_dims_cnt] = params.perm[i] - skip_dims_cnt; } return flat_size; From c0b5692bcdabec7b0cb28f4970e230765d173c17 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 13 Jun 2023 08:55:20 -0700 Subject: [PATCH 016/410] pad_operand as func --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 157 +++++------------- 1 file changed, 45 insertions(+), 112 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 
973f473030d7dd..ccd2e199d5d07d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -658,6 +658,43 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } + // Get the padded shape. + Shape pad_shape(const Shape old_shape, + const absl::Span batch_dims) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; + } + + // Pad the non-batch dimensions of the operands to multiples of 16 as + // required by cuBLASLt. + void pad_operand(absl::Span batch_dims, HloInstruction *&instr, + HloInstruction *&x) { + PaddingConfig padding_config; + Shape padded_shape = pad_shape(x->shape(), batch_dims); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); + } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = + instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + x = instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return; + } + StatusOr CreateF8CustomCall(HloInstruction *instr, GemmBackendConfig &gemm_backend_config, HloInstruction *a, HloInstruction *b, @@ -855,47 +892,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); } - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - }; - - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. 
- auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape()); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; - - pad_operand(a); - pad_operand(b); + pad_operand(batch_dims, instr, a); + pad_operand(batch_dims, instr, b); if (c != nullptr) { - pad_operand(c); + pad_operand(batch_dims, instr, c); } - Shape new_output_shape = pad_shape(instr->shape()); + Shape new_output_shape = pad_shape(instr->shape(), batch_dims); std::vector operands_list = { a, b, scales_f32[0], scales_f32[1], one, one}; @@ -1130,44 +1132,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { broadcast_bias = instr->AddInstruction( HloInstruction::CreateBitcast(slice->shape(), broadcast_bias)); } - absl::Span batch_dims = - config.dot_dimension_numbers().rhs_batch_dimensions(); - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - }; - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. - auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape()); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - pad_operand(broadcast_bias); + pad_operand(config.dot_dimension_numbers().rhs_batch_dimensions(), instr, + broadcast_bias); } operands.insert(operands.begin() + 2, broadcast_bias); @@ -1291,44 +1259,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - absl::Span batch_dims = - config.dot_dimension_numbers().rhs_batch_dimensions(); - // Get the padded shape. - auto pad_shape = [&batch_dims](const Shape old_shape) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - }; - - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. 
- auto pad_operand = [&instr, &pad_shape](HloInstruction *&x) -> void { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape()); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - }; if (bitcast) { - pad_operand(bias); + pad_operand(config.dot_dimension_numbers().rhs_batch_dimensions(), instr, + bias); } // Replace add(gemm, broadcast) with fused new_gemm. operands.push_back(bias); From b3e5b870cb55ccdac3bfbe049373619a5339c1b3 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 13 Jun 2023 13:40:01 -0700 Subject: [PATCH 017/410] lambda to func --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index ccd2e199d5d07d..71d7ea142b4013 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -531,7 +531,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // add(bitcast(gemm(a, b)), broadcast(bias)) -> // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) -> // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd) - // if (Match( instr, m::AddAnyOrder( @@ -787,16 +786,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // a matrix bias is only supported with CUDA 12 and above. HloInstruction *c = nullptr, *add = nullptr; - if (instr->user_count() == 1 && - instr->users()[0]->opcode() == HloOpcode::kAdd) { - HloInstruction *bias = instr->users()[0]->mutable_operand( - !instr->users()[0]->operand_index(instr)); - if (bias->opcode() != HloOpcode::kBroadcast) { - c = bias; - gemm_backend_config.set_beta(1.0); - add = instr->users()[0]; - } - } + // if (instr->user_count() == 1 && + // instr->users()[0]->opcode() == HloOpcode::kAdd) { + // HloInstruction *bias = instr->users()[0]->mutable_operand( + // !instr->users()[0]->operand_index(instr)); + // if (bias->opcode() != HloOpcode::kBroadcast) { + // c = bias; + // gemm_backend_config.set_beta(1.0); + // add = instr->users()[0]; + // } + // } // Each operand must have exactly one contracting and one non-contracting // dimension. @@ -925,6 +924,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr->shape(), new_custom_call, start_indices, instr->shape().dimensions(), strides)); } + TF_RETURN_IF_ERROR( ReplaceInstruction(add ? add : instr, slice ? slice : new_custom_call)); return true; @@ -1074,7 +1074,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const HloInstruction *gemm, HloInstruction *bitcast = nullptr, HloInstruction *slice = nullptr) { - TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())); + TF_RET_CHECK( + (bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())) || + (bias->shape() == (slice ? slice->shape() : gemm->shape()))); + // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. 
if (gemm->shape().element_type() == S32) { @@ -1122,13 +1125,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { (gemm->user_count() != 1) || !supported_epilogue) { return OkStatus(); } - config.set_beta(1.0); std::vector operands(gemm->operands().begin(), gemm->operands().end()); HloInstruction *broadcast_bias = MaybeConstantFoldBias(bias); - if (slice) { + if (slice && bitcast) { broadcast_bias = instr->AddInstruction( HloInstruction::CreateBitcast(slice->shape(), broadcast_bias)); } @@ -1177,7 +1179,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); } - return ReplaceWithNewInstruction(instr, std::move(fused_op)); } From c5ee2f190935170f8e535e15ada650136777ebb9 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 13 Jun 2023 13:45:29 -0700 Subject: [PATCH 018/410] Tidy and clean --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 71d7ea142b4013..5aa1d7804aaa11 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -500,7 +500,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } Status HandleAdd(HloInstruction *instr) override { - HloInstruction *bias, *existing_gemm(nullptr); + HloInstruction *bias, *existing_gemm = nullptr; HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; HloInstruction *optional_bitcast = nullptr; @@ -531,6 +531,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // add(bitcast(gemm(a, b)), broadcast(bias)) -> // bitcast(add(gemm(a, b), bitcast(broadcast(bias)))) -> // bitcast(gemm(a, b, bitcast(broadcast(bias)))) (FuseMatrixBiasAdd) + // if (Match( instr, m::AddAnyOrder( @@ -547,6 +548,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below. instr = new_add; } + // Do not fuse broadcast unless we can fuse its input, as it will cause // broadcast materialization. auto is_not_broadcast = [](const HloInstruction *instr) { @@ -581,12 +583,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Continue below transforming new_add. instr = new_add; } + if (Match(instr, m::AddAnyOrder( GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { return FuseMatrixBiasAdd(instr, bias, existing_gemm); } + HloInstruction *optional_bitcast_matrix = nullptr; HloInstruction *optional_slice_matrix = nullptr; if (Match(instr, @@ -781,22 +785,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - // Fuse the possible addition of a matrix bias here to enable the subsequent - // fusion of the scaling and conversion of D into the Custom Call. Fusing - // a matrix bias is only supported with CUDA 12 and above. - HloInstruction *c = nullptr, *add = nullptr; - - // if (instr->user_count() == 1 && - // instr->users()[0]->opcode() == HloOpcode::kAdd) { - // HloInstruction *bias = instr->users()[0]->mutable_operand( - // !instr->users()[0]->operand_index(instr)); - // if (bias->opcode() != HloOpcode::kBroadcast) { - // c = bias; - // gemm_backend_config.set_beta(1.0); - // add = instr->users()[0]; - // } - // } - // Each operand must have exactly one contracting and one non-contracting // dimension. 
absl::Span a_contracting_dims = @@ -893,16 +881,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { pad_operand(batch_dims, instr, a); pad_operand(batch_dims, instr, b); - if (c != nullptr) { - pad_operand(batch_dims, instr, c); - } Shape new_output_shape = pad_shape(instr->shape(), batch_dims); std::vector operands_list = { a, b, scales_f32[0], scales_f32[1], one, one}; - if (c != nullptr) { - operands_list.insert(operands_list.begin() + 2, c); - } HloInstruction *new_custom_call = instr->AddInstruction(HloInstruction::CreateCustomCall( @@ -926,7 +908,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } TF_RETURN_IF_ERROR( - ReplaceInstruction(add ? add : instr, slice ? slice : new_custom_call)); + ReplaceInstruction(instr, slice ? slice : new_custom_call)); return true; } @@ -1083,6 +1065,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (gemm->shape().element_type() == S32) { return OkStatus(); } + // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. bool can_overwrite_bias = [bias]() { @@ -1116,15 +1099,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { IsCublasLtMatmul(*gemm) || can_overwrite_bias; auto config = gemm->backend_config().value(); + // It is possible to fuse into a cublasLt matmul that already has a vector // bias, but no other epilogue will commute with the matrix bias add. bool supported_epilogue = ((config.epilogue() == GemmBackendConfig::DEFAULT) || (config.epilogue() == GemmBackendConfig::BIAS)); + if ((config.beta() != 0) || !want_to_fuse_bias || (gemm->user_count() != 1) || !supported_epilogue) { return OkStatus(); } + config.set_beta(1.0); std::vector operands(gemm->operands().begin(), @@ -1179,6 +1165,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); } + return ReplaceWithNewInstruction(instr, std::move(fused_op)); } @@ -1198,6 +1185,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } HloInstruction *bias = broadcast->mutable_operand(0); + TF_ASSIGN_OR_RETURN(auto config, gemm->backend_config()); // # output column dims == # non-contracting rhs operand dims. @@ -1218,6 +1206,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (bitcast) { broadcast_dims = gemm->shape().dimensions(); } + std::vector operands(gemm->operands().begin(), gemm->operands().end()); // When (non-trivial) matrix and vector bias co-exist for FP8 matmul, just @@ -1226,11 +1215,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { config.beta() != 0.0) { return true; } + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget && bias->shape().element_type() == F32) { if (convert == nullptr) { return false; } + HloInstruction *bias_f16_or_bf16 = convert->mutable_operand(0); auto compatible_bias_type = [](const PrimitiveType bias_type, const PrimitiveType output_type) { From 8b8be9a5a42544c3a7354fa0339b321533dc6043 Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Fri, 24 Mar 2023 16:01:26 +0000 Subject: [PATCH 019/410] [tosa] legalize matmul with quantized output This commit updates the quantized matmul operation to produce the correct quantized output. Specifically, it handles the following cases: - When the input is of type int8_t, the output is of type int32_t. - When the input is of type int16_t, the output is of type int48_t. 
Change-Id: I770aaaa93b8e9bc1467efbc2acd39a62ff9c06f7 --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 32 +++++++++++++++++++ .../mlir/tosa/transforms/legalize_tfl.cc | 29 +++++++++++++++-- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 145b1877761d8e..6bb5140e5c46df 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -1495,6 +1495,38 @@ func.func @test_batch_matmul_transpose(%arg0: tensor<1x16x128xf32>, %arg1: tenso // ----- +// CHECK-LABEL: test_batch_matmul_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4x!quant.uniform> +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x3x!quant.uniform> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_0]]) <{new_shape = array}> : (tensor<1x3x4x4x!quant.uniform>) -> tensor<3x4x4x!quant.uniform> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x3x4x3x!quant.uniform>) -> tensor<3x4x3x!quant.uniform> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.matmul"(%[[VAL_2]], %[[VAL_3]]) <{quantization_info = #tosa.matmul_quant}> : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>) -> tensor<3x4x3xi32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<3x4x3xi32>) -> tensor<1x3x4x3xi32> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.rescale"(%[[VAL_5]]) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array}> : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: return %[[VAL_6]] : tensor<1x3x4x3x!quant.uniform> +func.func @test_batch_matmul_qi8(%arg0: tensor<1x3x4x4x!quant.uniform>, %arg1: tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> { + %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x3x4x4x!quant.uniform>, tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> + return %0 : tensor<1x3x4x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_batch_matmul_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4x!quant.uniform>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.reshape"(%[[VAL_0]]) <{new_shape = array}> : (tensor<1x3x4x4x!quant.uniform>) -> tensor<3x4x4x!quant.uniform> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<1x3x4x3x!quant.uniform>) -> tensor<3x4x3x!quant.uniform> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.matmul"(%[[VAL_2]], %[[VAL_3]]) <{quantization_info = #tosa.matmul_quant}> : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>) -> tensor<3x4x3xi48> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<3x4x3xi48>) -> tensor<1x3x4x3xi48> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.rescale"(%[[VAL_5]]) <{double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = false, shift = array}> : (tensor<1x3x4x3xi48>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: return %[[VAL_6]] : tensor<1x3x4x3x!quant.uniform> +func.func @test_batch_matmul_qi16(%arg0: tensor<1x3x4x4x!quant.uniform>, %arg1: tensor<1x3x4x3x!quant.uniform>) -> (tensor<1x3x4x3x!quant.uniform>) { +%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, 
asymmetric_quantize_inputs = false} : (tensor<1x3x4x4x!quant.uniform>, tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> +return %0 : tensor<1x3x4x3x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_add_scalar // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> // CHECK: %[[VAR2:.*]] = "tosa.add"(%arg0, %[[VAR0]]) diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 87573d30ed56d5..c25c58af9be105 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -1936,10 +1936,35 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( .getResult(); } + Type output_ety; + if (result_is_qtype) { + auto lhs_qty_width = lhs_ty.getElementType() + .cast() + .getStorageTypeIntegralWidth(); + auto rhs_qty_width = rhs_ty.getElementType() + .cast() + .getStorageTypeIntegralWidth(); + + if (lhs_qty_width != rhs_qty_width) { + return rewriter.notifyMatchFailure( + op, "input tensors should have same qtype storage width"); + } + + if (lhs_qty_width == 8) { + output_ety = rewriter.getI32Type(); + } else if (lhs_qty_width == 16) { + output_ety = rewriter.getIntegerType(48); + } else { + return rewriter.notifyMatchFailure( + op, "only support 8-bit or 16-bit quantized type"); + } + } else { + output_ety = result_ty.getElementType(); + } + auto matmul = CreateOpAndInfer( - rewriter, op->getLoc(), - UnrankedTensorType::get(result_ty.getElementType()), lhs, rhs) + rewriter, op->getLoc(), UnrankedTensorType::get(output_ety), lhs, rhs) .getResult(); // Conditionally reshape rank back to expected rank. From e160eb8376413b317cf3fd64d753235a23a3ef87 Mon Sep 17 00:00:00 2001 From: Kanvi Khanna Date: Tue, 27 Jun 2023 11:17:50 -0700 Subject: [PATCH 020/410] Fix build errors --- .../generic_layout_optimizer_transposer.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index c745e07380c705..82726b0c853a06 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -1027,7 +1027,7 @@ Status MaxPool3DTransposer::TransposeNode(TransposeContext* context, auto* data_fanin_node = data_fanin.node_view(); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 5)) { - return Status::OK(); + return OkStatus(); } ScopedDataFormatUpgrader data_format_upgrader(context, 5); VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -1468,7 +1468,7 @@ Status MergeTransposer::TransposeNode(TransposeContext* context, DCHECK(IsMerge(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { - return Status::OK(); + return OkStatus(); } ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || @@ -1721,7 +1721,7 @@ Status SplitTransposer::TransposeNode(TransposeContext* context, int rank = 4; if (!IsFanoutPortsRankN(*node, ports, 4)) { if (!IsFanoutPortsRankN(*node, ports, 5)) { - return Status::OK(); + return OkStatus(); } else { rank = 5; } @@ -1746,7 +1746,7 @@ Status SplitVTransposer::TransposeNode(TransposeContext* context, int rank = 4; if (!IsFanoutPortsRankN(*node, ports, 4)) { if 
(!IsFanoutPortsRankN(*node, ports, 5)) { - return Status::OK(); + return OkStatus(); } else { rank = 5; } @@ -1927,7 +1927,7 @@ Status StridedSliceTransposer::TransposeNode(TransposeContext* context, DCHECK(IsStridedSlice(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { - return Status::OK(); + return OkStatus(); } ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || !HasOnlyBeginEndMask(*node) || @@ -1955,7 +1955,7 @@ Status SwitchTransposer::TransposeNode(TransposeContext* context, DCHECK(IsSwitch(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { - return Status::OK(); + return OkStatus(); } ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || @@ -1973,7 +1973,7 @@ Status TernaryOpTransposer::TransposeNode(TransposeContext* context, DCHECK(IsTernaryOp(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { - return Status::OK(); + return OkStatus(); } ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || From c699c44c9254a68c04fc53d2c472fc11932397e8 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:17:12 +0530 Subject: [PATCH 021/410] Fix aligned alloc feature condition on C++14 compiler The updated condition will only compile if both C++17 and C11 compliance are met, guaranteeing that the aligned_alloc feature is available. Relevant closed PR #57707 Fixes #57706 --- tensorflow/lite/core/interpreter_builder.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index 0c508e83f5b6b4..d1cf3eec2db7c1 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -49,7 +49,8 @@ limitations under the License. #include "tensorflow/lite/version.h" // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11. -#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L +//(introduced in stdc11 but realized in C++17) +#if __cplusplus >= 201703L && __STDC_VERSION__ >= 201112L #if !defined(__ANDROID__) || __ANDROID_API__ >= 28 // Neither Apple nor Windows provide aligned_alloc. #if !defined(__APPLE__) && !defined(_WIN32) From 4f52f7b7ebafa498054454217bacdff18fd56e65 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:23:17 +0530 Subject: [PATCH 022/410] Fix aligned alloc feature condition on C++14 compiler The updated condition will only compile if both C++17 and C11 compliance are met, guaranteeing that the aligned_alloc feature is available. --- .../lite/kernels/internal/optimized/neon_tensor_utils.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 71755d851e7c23..67da19fc95a773 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -36,7 +36,8 @@ limitations under the License. #ifdef USE_NEON // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11. 
-#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L +//(introduced in stdc11 but realized in C++17) +#if __cplusplus >= 201703L && __STDC_VERSION__ >= 201112L #if !defined(__ANDROID__) || __ANDROID_API__ >= 28 // Neither Apple nor Windows provide aligned_alloc. #if !defined(__APPLE__) && !defined(_WIN32) From eb9e49a995ce1d2108700152364181a3679ec508 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 22 May 2023 20:24:27 +0000 Subject: [PATCH 023/410] [Tosa] Add legalization of BroadcastTo For tf/tfl broadcast-to operator with input and shape where: - shape is compile time constant, and - shape's rank is greater than or equal to input's rank, and - input element type is not complex, and - input element type is not integer whose bitwidth is greater than 32 will convert to tosa operators as follows: 1. if input element type is floating point, add input with constant -0.f of the broadcast shape 2. if input element type is i1, logical-or input with constant 'false' of the broadcast shape 3. if input element type is i32, add input with constant 0 (i32) of the broadcast shape 4. otherwise, cast input to i32, add with constant 0 (i32) of the broadcast shape, and cast back to original element type added tf/tfl lit tests Signed-off-by: Tai Ly Change-Id: I12302adcf1c791d452a5b5a928e63e5ffcd523bc --- .../mlir/tosa/tests/tf-to-tosa-pipeline.mlir | 66 ++++++++ .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 91 +++++++++++ .../mlir/tosa/transforms/legalize_common.cc | 144 ++++++++++++++++++ .../mlir/tosa/transforms/legalize_common.h | 5 + .../mlir/tosa/transforms/legalize_tf.cc | 17 +++ .../mlir/tosa/transforms/legalize_tfl.cc | 17 +++ 6 files changed, 340 insertions(+) diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 47e2571e2bb13d..5c863439436d01 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -1071,3 +1071,69 @@ func.func @mirrorpad_reflect(%arg0: tensor<13x21x3xf32>) -> tensor<14x22x4xf32> %1 = "tf.Identity"(%0) {device = ""} : (tensor<14x22x4xf32>) -> tensor<14x22x4xf32> return %0 : tensor<14x22x4xf32> } + +// ----- + +// CHECK-LABEL: test_broadcast_to_f32 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32> +// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32> +func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf32>) { + %shape = "tf.Const"() {value = dense<[3, 3, 1, 7]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xf32>, tensor<4xi32>) -> tensor<3x3x13x7xf32> + return %1 : tensor<3x3x13x7xf32> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i32 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> +// CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32> +func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi32>) { + %shape = "tf.Const"() {value = dense<[7, 7, 13, 3]> : tensor<4xi32>} : () -> 
tensor<4xi32> + %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi32>, tensor<4xi32>) -> tensor<3x3x13x3xi32> + return %1 : tensor<3x3x13x3xi32> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i1 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense : tensor<7x7x13x7xi1>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xi1> +// CHECK: %[[VAL_2:.*]] = "tosa.logical_or"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1> +// CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1> +func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) { + %shape = "tf.Const"() {value = dense<[7, 7, 13, 7]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi1>, tensor<4xi32>) -> tensor<7x7x13x7xi1> + return %1 : tensor<7x7x13x7xi1> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i16 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} +// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%1) : (tensor<1x1x13x1xi16>) -> tensor<1x1x13x1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%3) : (tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi16> +// CHECK: return %[[VAL_4]] : tensor<7x7x13x3xi16> +func.func @test_broadcast_to_i16(%arg0: tensor<13x1xi16>) -> (tensor<7x7x13x3xi16>) { + %shape = "tf.Const"() {value = dense<[7, 7, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.BroadcastTo"(%arg0, %shape) : (tensor<13x1xi16>, tensor<4xi32>) -> tensor<7x7x13x3xi16> + return %1 : tensor<7x7x13x3xi16> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_smaller_rank +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[13, 7]> : tensor<2xi32>} +// CHECK: %[[VAL_1:.*]] = "tf.BroadcastTo"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32> +// CHECK: return %[[VAL_1]] : tensor<13x7xi32> +func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) { + %s = "tf.Const"() {value = dense<[13, 7]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.BroadcastTo"(%arg0, %s) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32> + return %1 : tensor<13x7xi32> +} diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 68e68345ce7436..8da78d19499041 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2804,3 +2804,94 @@ func.func @test_squared_difference_f32(%arg0: tensor<1x197x768xf32>, %arg1: tens %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x197x768xf32>, tensor<1x197x1xf32>) -> tensor<1x197x768xf32> func.return %0 : tensor<1x197x768xf32> } + +// ----- + +// CHECK-LABEL: test_broadcast_to_f32 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32> +// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32> +func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf32>) { + %shape = arith.constant dense<[3, 3, 1, 7]> : tensor<4xi32> + %1 = 
"tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xf32>, tensor<4xi32>) -> tensor<3x3x13x7xf32> + return %1 : tensor<3x3x13x7xf32> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_f16 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xf16>) +// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xf16>, tensor<3x3x13x7xf16>) -> tensor<3x3x13x7xf16> +// CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf16> +func.func @test_broadcast_to_f16(%arg0: tensor<13x1xf16>) -> (tensor<3x3x13x7xf16>) { + %shape = arith.constant dense<[3, 3, 1, 7]> : tensor<4xi32> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xf16>, tensor<4xi32>) -> tensor<3x3x13x7xf16> + return %1 : tensor<3x3x13x7xf16> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i32 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.add"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> +// CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32> +func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi32>) { + %shape = arith.constant dense<[7, 7, 13, 3]> : tensor<4xi64> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xi32>, tensor<4xi64>) -> tensor<3x3x13x3xi32> + return %1 : tensor<3x3x13x3xi32> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i1 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense : tensor<7x7x13x7xi1>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array}> : (tensor<13x1xi1> +// CHECK: %[[VAL_2:.*]] = "tosa.logical_or"(%[[VAL_1]], %[[VAL_0]]) : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1> +// CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1> +func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) { + %shape = arith.constant dense<[7, 7, 13, 7]> : tensor<4xi64> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1xi1>, tensor<4xi64>) -> tensor<7x7x13x7xi1> + return %1 : tensor<7x7x13x7xi1> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_qi8 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) <{new_shape = array} +// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%1) : (tensor<1x1x13x1x!quant.uniform>) -> tensor<1x1x13x1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_0]]) : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%3) : (tensor<7x7x13x3xi32>) -> tensor<7x7x13x3x!quant.uniform> +// CHECK: return %[[VAL_4]] : tensor<7x7x13x3x!quant.uniform> +func.func @test_broadcast_to_qi8(%arg0: tensor<13x1x!quant.uniform>) -> (tensor<7x7x13x3x!quant.uniform>) { + %shape = arith.constant dense<[7, 7, 1, 3]> : tensor<4xi64> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<13x1x!quant.uniform>, tensor<4xi64>) -> tensor<7x7x13x3x!quant.uniform> + return %1 : tensor<7x7x13x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_smaller_rank +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[13, 7]> : tensor<2xi48>} +// CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi48>) -> tensor<13x7xi32> +// CHECK: return %[[VAL_1]] : tensor<13x7xi32> +func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) 
{ + %shape = arith.constant dense<[13, 7]> : tensor<2xi64> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<2x3x13x1xi32>, tensor<2xi64>) -> tensor<13x7xi32> + return %1 : tensor<13x7xi32> +} + +// ----- + +// CHECK-LABEL: test_broadcast_to_i48 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[7, 7, 1, 7]> : tensor<4xi48>} +// CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<1x1x13x1xi48>, tensor<4xi48>) -> tensor<7x7x13x7xi48> +// CHECK: return %[[VAL_1]] : tensor<7x7x13x7xi48> +func.func @test_broadcast_to_i48(%arg0: tensor<1x1x13x1xi48>) -> (tensor<7x7x13x7xi48>) { + %shape = arith.constant dense<[7, 7, 1, 7]> : tensor<4xi64> + %1 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<1x1x13x1xi48>, tensor<4xi64>) -> tensor<7x7x13x7xi48> + return %1 : tensor<7x7x13x7xi48> +} diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index d8785910105ea2..1c37792a84d06e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -4610,5 +4610,149 @@ std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, .getResult(); } +// Lowers BroadcastTo operator to a sequence of TOSA ops. +std::optional convertBroadcastToOp(PatternRewriter& rewriter, + Operation* op, Value input, + Value shape) { + RankedTensorType input_type = dyn_cast(input.getType()); + if (!input_type) { + (void)rewriter.notifyMatchFailure(op, "input type not ranked tensor"); + return std::nullopt; + } + + Type element_type = input_type.getElementType(); + if (element_type.isa()) { + (void)rewriter.notifyMatchFailure(op, "input element type is complex"); + return std::nullopt; + } + + if (element_type.isa()) { + auto bitwidth = element_type.getIntOrFloatBitWidth(); + if (bitwidth > 32) { + (void)rewriter.notifyMatchFailure(op, "input element type has greater than 32 bits"); + return std::nullopt; + } + } + + ElementsAttr shape_elems; + if (!matchPattern(shape, m_Constant(&shape_elems))) { + (void)rewriter.notifyMatchFailure(op, "shape is not constant"); + return std::nullopt; + } + int input_rank = input_type.getRank(); + int shape_rank = shape_elems.getNumElements(); + + if (auto shape_type = dyn_cast(shape.getType())) { + if (shape_type.hasStaticShape()) { + assert(shape_type.getRank() == 1); + if (!shape_type.isDynamicDim(0) && + shape_rank != shape_type.getDimSize(0)) { + // shape_elems and shape's type's 'are different + // this is not supported for now + (void)rewriter.notifyMatchFailure( + op, + "shape's constant value has different elements than its static " + "dimension"); + return std::nullopt; + } + } + } + + if (input_rank > shape_rank) { + // not clear what to do in this case, bail for now + (void)rewriter.notifyMatchFailure(op, "shape has less rank than input"); + return std::nullopt; + } + + // equalize new_rank and input_rank + if (input_rank < shape_rank) { + // reshape input to shape_rank + SmallVector reshaped_shape((shape_rank - input_rank), 1); + for (auto dim : input_type.getShape()) { + reshaped_shape.push_back(dim); + } + input_type = + tensorflow::GetTypeFromTFTensorShape(reshaped_shape, element_type); + input = CreateOpAndInfer( + rewriter, op->getLoc(), input_type, input, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshaped_shape))); + } + + auto input_shape = input_type.getShape(); + assert(input_shape.size() == shape_rank); // should be equal ranks by now + + // construct new_shape 
as broadcasted shape of input_shape and shape_elems + int32_t num_elements = 1; + SmallVector new_shape; + for (int i = 0; i < shape_rank; i++) { + auto shape_dim = shape_elems.getValues()[i].getInt(); + auto input_dim = input_shape[i]; + if (shape_dim != input_dim && std::min(shape_dim, input_dim) != 1) { + // shape_dim and input_dim are different, but the lower value is not 1 + // this is not broadcastable + (void)rewriter.notifyMatchFailure( + op, "input and shape are not broadcastable"); + return std::nullopt; + } + auto dim = std::max(shape_dim, input_dim); + new_shape.push_back(dim); + num_elements *= dim; + } + + RankedTensorType output_type = + tensorflow::GetTypeFromTFTensorShape(new_shape, element_type); + + if (element_type.isa()) { + // F32: legalize to broadcastable Add with (-0.f) + SmallVector values(num_elements, -0.f); + auto const_attr = + DenseElementsAttr::get(output_type, llvm::ArrayRef(values)); + Value f32_const_zero = + rewriter.create(op->getLoc(), output_type, const_attr); + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, + input, f32_const_zero) + .getResult(); + } + + if (element_type.isInteger(1)) { + // I1: legalize to broadcastable LogicalOr with false + SmallVector values(num_elements, 0); + auto const_attr = + DenseElementsAttr::get(output_type, llvm::ArrayRef(values)); + Value i1_const_zero = + rewriter.create(op->getLoc(), output_type, const_attr); + return CreateOpAndInfer( + rewriter, op->getLoc(), output_type, input, i1_const_zero) + .getResult(); + } + + SmallVector values(num_elements, 0); + RankedTensorType I32_shaped_type = + tensorflow::GetTypeFromTFTensorShape(new_shape, rewriter.getI32Type()); + auto const_attr = + DenseElementsAttr::get(I32_shaped_type, llvm::ArrayRef(values)); + Value I32_const_zero = + rewriter.create(op->getLoc(), I32_shaped_type, const_attr); + + if (element_type.isInteger(32)) { + // I32: legalize to broadcastable Add with 0 + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, + input, I32_const_zero) + .getResult(); + } + + // for any other non-float element type: + // cast input to I32, Add with 0(I32), then cast back to output type + Value input_cast = CreateOpAndInfer( + rewriter, op->getLoc(), + /* I32 input type */ input_type.clone(rewriter.getI32Type()), input); + Value add_const = CreateOpAndInfer( + rewriter, op->getLoc(), I32_shaped_type, input_cast, I32_const_zero); + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, + add_const) + .getResult(); +} + }; // namespace tosa }; // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 3dc87952753583..52073400b079f2 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -306,6 +306,11 @@ std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type); +// Lowers BroadcastTo operator to a sequence of TOSA ops. 
+std::optional convertBroadcastToOp(PatternRewriter& rewriter, + Operation* op, Value input, + Value shape); + }; // namespace tosa }; // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 8b20350e5f5bef..313f06961140ce 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -151,6 +151,7 @@ DECL_CONVERT_OP(LeftShift); DECL_CONVERT_OP(RightShift); DECL_CONVERT_OP(OneHot); DECL_CONVERT_OP(BatchMatMulV2); +DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP LogicalResult ConvertTFReluOp::matchAndRewrite( @@ -2421,6 +2422,21 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( return success(); } +LogicalResult ConvertTFBroadcastToOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tf_broadcast_to_op = cast(op); + + std::optional result = + convertBroadcastToOp(rewriter, op, tf_broadcast_to_op.getInput(), + tf_broadcast_to_op.getShape()); + + if (!result) return failure(); + + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + void LegalizeTF::runOnOperation() { auto* ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -2523,6 +2539,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); } // Creates an instance of the TensorFlow dialect LegalizeTF pass. diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 8863a3d0df4031..bd53b5a3acd974 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -193,6 +193,7 @@ DECL_CONVERT_OP(While); DECL_CONVERT_OP(Real); DECL_CONVERT_OP(Imag); DECL_CONVERT_OP(RFFT2d); +DECL_CONVERT_OP(BroadcastTo); #undef DECL_CONVERT_OP @@ -4478,6 +4479,21 @@ LogicalResult ConvertTFLRFFT2dOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLBroadcastToOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_broadcast_to_op = cast(op); + + std::optional result = + convertBroadcastToOp(rewriter, op, tfl_broadcast_to_op.getInput(), + tfl_broadcast_to_op.getShape()); + + if (!result) return failure(); + + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); @@ -4615,6 +4631,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLReal); DEF_PATTERN_INSERT(TFLImag); DEF_PATTERN_INSERT(TFLRFFT2d); + DEF_PATTERN_INSERT(TFLBroadcastTo); } // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass. 
From 45a3dff3490af05b88f87faf319a1fe1b71a9007 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 14 Jul 2023 15:54:27 -0700 Subject: [PATCH 024/410] Change errors::* -> absl::*, and necessary formatting etc changes --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 197 ++++++++++---------- 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 8dae3705e0a811..24a0259cdef3bf 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -620,7 +620,7 @@ class MklConvOp : public OpKernel { context, !(context->HasAttr("padding_list") && context->HasAttr("explicit_paddings")), - errors::InvalidArgument("Can only have 1 `padding` list at most")); + absl::InvalidArgumentError("Can only have 1 `padding` list at most")); if (context->HasAttr("padding_list")) { OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); } @@ -632,17 +632,17 @@ class MklConvOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_)); OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_), - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), - errors::InvalidArgument("Sliding window strides field must " - "specify 4 or 5 dimensions")); + absl::InvalidArgumentError("Sliding window strides field must " + "specify 4 or 5 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); + absl::UnimplementedError("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); is_filter_const_ = false; @@ -654,28 +654,29 @@ class MklConvOp : public OpKernel { } if (strides_.size() == 4) { - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); + OP_REQUIRES( + context, dilations_.size() == 4, + absl::InvalidArgumentError("Sliding window dilations field must " + "specify 4 dimensions")); const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations in the batch and depth dimensions.")); OP_REQUIRES( context, dilation_h > 0 && dilation_w > 0, - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } else if (strides_.size() == 5) { OP_REQUIRES(context, dilations_.size() == 5, - errors::InvalidArgument("Dilation rates field must " - "specify 5 dimensions")); + absl::InvalidArgumentError("Dilation rates field must " + "specify 5 dimensions")); OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && 
GetTensorDim(dilations_, data_format_, 'C') == 1), - errors::InvalidArgument( + absl::InvalidArgumentError( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( @@ -683,7 +684,7 @@ class MklConvOp : public OpKernel { (GetTensorDim(dilations_, data_format_, '0') > 0 && GetTensorDim(dilations_, data_format_, '1') > 0 && GetTensorDim(dilations_, data_format_, '2') > 0), - errors::InvalidArgument("Dilated rates should be larger than 0.")); + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); } } @@ -694,8 +695,8 @@ class MklConvOp : public OpKernel { const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter); OP_REQUIRES( context, filter_tensor.NumElements() > 0, - errors::InvalidArgument("filter must not have zero elements " - "(i.e. all dimensions must be non-zero)")); + absl::InvalidArgumentError("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); if (std::is_same::value) { (void)SetFPMathMode(); @@ -707,8 +708,8 @@ class MklConvOp : public OpKernel { native_format); OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(), - errors::InvalidArgument("Filter should not be in " - "Mkl Layout")); + absl::InvalidArgumentError("Filter should not be in " + "Mkl Layout")); MklDnnData src(&cpu_engine_); MklDnnData filter(&cpu_engine_); @@ -780,18 +781,18 @@ class MklConvOp : public OpKernel { bool is_conv3d = (strides_.size() == 5); if (!is_conv2d && !is_conv3d) { - OP_REQUIRES( - context, !pad_enabled, - errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D")); + OP_REQUIRES(context, !pad_enabled, + absl::InvalidArgumentError( + "Pad + Conv fusion only works for 2D/3D")); OP_REQUIRES( context, !fuse_pad_, - errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D")); + absl::InvalidArgumentError("Pad+Conv fusion only works for 2D/3D")); } // TODO(intel-tf) 3-D support for Depthwise is not there if (is_depthwise) { OP_REQUIRES(context, is_conv2d, - errors::InvalidArgument( + absl::InvalidArgumentError( "Only 2D convolution is supported for depthwise.")); } @@ -804,7 +805,7 @@ class MklConvOp : public OpKernel { auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, - errors::InvalidArgument("Invalid data format")); + absl::InvalidArgumentError("Invalid data format")); // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. 
@@ -862,8 +863,9 @@ class MklConvOp : public OpKernel { // Inputs to FusedBatchNorm have same 1D shape fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape(); OP_REQUIRES(context, fuse_bn_shape.dims() == 1, - errors::InvalidArgument("FusedBatchNorm must be 1D, not: ", - fuse_bn_shape.DebugString())); + absl::InvalidArgumentError( + absl::StrCat("FusedBatchNorm must be 1D, not: ", + fuse_bn_shape.DebugString()))); // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1}; @@ -991,9 +993,9 @@ class MklConvOp : public OpKernel { string error_msg = tensorflow::strings::StrCat( "Status: ", e.status, ", message: ", string(e.message), ", in file ", __FILE__, ":", __LINE__); - OP_REQUIRES_OK( - context, - errors::Aborted("Operation received an exception:", error_msg)); + OP_REQUIRES_OK(context, + absl::AbortedError(absl::StrCat( + "Operation received an exception:", error_msg))); } } @@ -1006,8 +1008,9 @@ class MklConvOp : public OpKernel { } else { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); OP_REQUIRES(context, paddings_tf.dims() == 2, - errors::InvalidArgument("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString())); + absl::InvalidArgumentError( + absl::StrCat("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString()))); // Flatten tensor to get individual paddings. paddings = static_cast( const_cast(paddings_tf.flat().data())); @@ -1102,9 +1105,9 @@ class MklConvOp : public OpKernel { virtual void ComputeBNScale(OpKernelContext* context, float epsilon, int bn_variance_index, Tinput* scale_buf_ptr) { - OP_REQUIRES( - context, false, - errors::Unimplemented("Compute BN scale not expected in base class")); + OP_REQUIRES(context, false, + absl::UnimplementedError( + "Compute BN scale not expected in base class")); return; } @@ -1215,7 +1218,7 @@ class MklConvOp : public OpKernel { auto output_format_tag = MklTensorFormatToMklDnnDataFormat( output_mkl_shape->GetTfDataFormat()); OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, - errors::InvalidArgument( + absl::InvalidArgumentError( "MklConvOp: AddN fusion: Invalid data format")); auto add_md = add_mkl_shape.IsMklTensor() @@ -1493,14 +1496,14 @@ class MklFusedConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have at least one fused op.")); // TODO(intel-tf): Compact the code for activation checking if (fused_ops == std::vector{"BiasAdd"}) { this->set_fuse_biasadd(true); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"Relu"}) { this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1519,26 +1522,26 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); } else if (fused_ops == std::vector{"BiasAdd", "Relu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == 
std::vector{"BiasAdd", "Relu6"}) { this->set_fuse_biasadd(true); this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "LeakyRelu"}) { this->set_fuse_biasadd(true); @@ -1548,21 +1551,21 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, leakyrelu_alpha); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"FusedBatchNorm", "Relu"}) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1571,7 +1574,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->SET_FUSE_ACTIVATION_FOR_RELU6; @@ -1580,7 +1583,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); @@ -1592,7 +1595,7 @@ class MklFusedConvOp context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, @@ -1603,7 +1606,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); @@ -1613,7 +1616,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu6"}) { this->set_fuse_biasadd(true); @@ -1621,7 +1624,7 @@ class MklFusedConvOp this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two 
extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Elu"}) { this->set_fuse_biasadd(true); @@ -1629,7 +1632,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "LeakyRelu"}) { @@ -1642,24 +1645,25 @@ class MklFusedConvOp leakyrelu_alpha); OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Mish"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_mish, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "_FusedConv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "_MklSwish"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv2D must have one extra argument: bias.")); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } if (pad_enabled) { @@ -1706,7 +1710,7 @@ class MklFusedDepthwiseConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have at least one fused op.")); if (fused_ops == std::vector{"BiasAdd"}) { @@ -1722,13 +1726,14 @@ class MklFusedDepthwiseConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); } else { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } OP_REQUIRES( context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused DepthwiseConv2D must have one extra argument: bias.")); if (pad_enabled) { @@ -1796,7 +1801,7 @@ class MklQuantizedConvOp // TODO(intel-tf): num_fused_ops and legacy_fused_ops should go away once // old API is abandoned. OP_REQUIRES(context, !(fused_ops_attr.size() > 0 && num_fused_ops > 0), - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv fused ops should be only available through " "either new API or old API, got both.")); @@ -1813,8 +1818,9 @@ class MklQuantizedConvOp std::find(supported_fusions.begin(), supported_fusions.end(), fused_ops_) != supported_fusions.end(); OP_REQUIRES(context, is_fusion_supported, - errors::InvalidArgument("Unsupported QuantizedConv fusion: [", - absl::StrJoin(fused_ops_, ","), "]")); + absl::InvalidArgumentError( + absl::StrCat("Unsupported QuantizedConv fusion: [", + absl::StrJoin(fused_ops_, ","), "]"))); } // Set the flag for every fused op. 
@@ -1838,9 +1844,10 @@ class MklQuantizedConvOp const bool fuse_requantize = IsFused(oneDNNFusedOps::kRequantize); OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_dt)); if (fuse_requantize) { - OP_REQUIRES(context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, - errors::InvalidArgument("QuantizedConv: unsupported output " - "type when Requantize is fused.")); + OP_REQUIRES( + context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, + absl::InvalidArgumentError("QuantizedConv: unsupported output " + "type when Requantize is fused.")); } if (context->HasAttr("Tsummand")) { @@ -1848,7 +1855,7 @@ class MklQuantizedConvOp if (!this->get_fuse_add()) { OP_REQUIRES( context, summand_dt == out_dt, - errors::InvalidArgument( + absl::InvalidArgumentError( "QuantizedConv: incorrect summand data type. When Sum is not " "fused, Tsummand attribute must have same value as out_type.")); } @@ -1874,7 +1881,7 @@ class MklQuantizedConvOp OP_REQUIRES( context, is_filter_const, - errors::InvalidArgument("QuantizedConv: filter must be a constant")); + absl::InvalidArgumentError("QuantizedConv: filter must be a constant")); if (num_fused_ops == -1) { // If num_fused_ops is -1 then the new API (ops) are being used. @@ -2020,13 +2027,13 @@ class MklQuantizedConvOp context->input(max_input_idx_).template scalar()(); const Tensor& min_filter_vector = context->input(min_filter_idx_); const Tensor& max_filter_vector = context->input(max_filter_idx_); - OP_REQUIRES( - context, - ((min_filter_vector.NumElements() > 0) && - (max_filter_vector.NumElements() > 0) && - (min_filter_vector.shape() == max_filter_vector.shape())), - errors::InvalidArgument("`min_ and max_filter` must have same" - "shape and contain at least one element.")); + OP_REQUIRES(context, + ((min_filter_vector.NumElements() > 0) && + (max_filter_vector.NumElements() > 0) && + (min_filter_vector.shape() == max_filter_vector.shape())), + absl::InvalidArgumentError( + "`min_ and max_filter` must have same" + "shape and contain at least one element.")); // min_freezed_output and max_freezed_output are the actual range // for the output. 
@@ -2086,15 +2093,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`min_freezed_output` must be rank 0 but is rank ", - min_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`min_freezed_output` must be rank 0 but is rank ", + min_freezed_output_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_output_tensor.shape()), - errors::InvalidArgument( - "`max_freezed_output` must be rank 0 but is rank ", - max_freezed_output_tensor.dims())); + absl::InvalidArgumentError( + absl::StrCat("`max_freezed_output` must be rank 0 but is rank ", + max_freezed_output_tensor.dims()))); const Tensor& min_freezed_summand_tensor = context->input(min_summand_idx_); const Tensor& max_freezed_summand_tensor = @@ -2102,15 +2109,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`min_freezed_summand` must be rank 0 but is rank ", - min_freezed_summand_tensor.dims())); + min_freezed_summand_tensor.dims()))); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_summand_tensor.shape()), - errors::InvalidArgument( + absl::InvalidArgumentError(absl::StrCat( "`max_freezed_summand` must be rank 0 but is rank ", - max_freezed_summand_tensor.dims())); + max_freezed_summand_tensor.dims()))); const float min_freezed_output = min_freezed_output_tensor.template scalar()(); const float max_freezed_output = @@ -2185,7 +2192,7 @@ class MklQuantizedConvOp OP_REQUIRES(context, context->forward_input_to_output_with_shape( summand_idx, 0, summand.shape(), output_tensor), - errors::InvalidArgument( + absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); return; } @@ -2466,13 +2473,14 @@ class MklFusedConv3DOp std::vector padding_list; OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); if (padding_list.empty()) { - OP_REQUIRES(context, !fused_ops.empty(), - errors::InvalidArgument("Fused Conv3D must have at least one " - "fused op when Pad is not fused.")); + OP_REQUIRES( + context, !fused_ops.empty(), + absl::InvalidArgumentError("Fused Conv3D must have at least one " + "fused op when Pad is not fused.")); if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end()) { OP_REQUIRES(context, num_args == 1, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have one extra argument: bias.")); } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end() && @@ -2480,7 +2488,7 @@ class MklFusedConv3DOp fused_ops.end()) { OP_REQUIRES( context, num_args == 2, - errors::InvalidArgument( + absl::InvalidArgumentError( "Fused Conv3D must have two extra arguments: bias and add.")); } } @@ -2533,8 +2541,9 @@ class MklFusedConv3DOp } else { if (padding_list.empty()) { OP_REQUIRES(context, false, - errors::Unimplemented("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]")); + absl::UnimplementedError( + absl::StrCat("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]"))); } } } From 8248b0175a25e9e18ec0917cf924c2b9da793f8b Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Fri, 14 Jul 2023 16:47:26 -0700 Subject: [PATCH 025/410] Revert "Change errors::* -> absl::*, and necessary formatting etc changes" This reverts commit 45a3dff3490af05b88f87faf319a1fe1b71a9007. 
--- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 197 ++++++++++---------- 1 file changed, 94 insertions(+), 103 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 24a0259cdef3bf..8dae3705e0a811 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -620,7 +620,7 @@ class MklConvOp : public OpKernel { context, !(context->HasAttr("padding_list") && context->HasAttr("explicit_paddings")), - absl::InvalidArgumentError("Can only have 1 `padding` list at most")); + errors::InvalidArgument("Can only have 1 `padding` list at most")); if (context->HasAttr("padding_list")) { OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); } @@ -632,17 +632,17 @@ class MklConvOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_)); OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_), - absl::InvalidArgumentError("Invalid data format")); + errors::InvalidArgument("Invalid data format")); OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), - absl::InvalidArgumentError("Sliding window strides field must " - "specify 4 or 5 dimensions")); + errors::InvalidArgument("Sliding window strides field must " + "specify 4 or 5 dimensions")); const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, - absl::UnimplementedError("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); is_filter_const_ = false; @@ -654,29 +654,28 @@ class MklConvOp : public OpKernel { } if (strides_.size() == 4) { - OP_REQUIRES( - context, dilations_.size() == 4, - absl::InvalidArgumentError("Sliding window dilations field must " - "specify 4 dimensions")); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Current implementation does not yet support " "dilations in the batch and depth dimensions.")); OP_REQUIRES( context, dilation_h > 0 && dilation_w > 0, - absl::InvalidArgumentError("Dilated rates should be larger than 0.")); + errors::InvalidArgument("Dilated rates should be larger than 0.")); } else if (strides_.size() == 5) { OP_REQUIRES(context, dilations_.size() == 5, - absl::InvalidArgumentError("Dilation rates field must " - "specify 5 dimensions")); + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && GetTensorDim(dilations_, data_format_, 'C') == 1), - absl::InvalidArgumentError( + errors::InvalidArgument( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( @@ 
-684,7 +683,7 @@ class MklConvOp : public OpKernel { (GetTensorDim(dilations_, data_format_, '0') > 0 && GetTensorDim(dilations_, data_format_, '1') > 0 && GetTensorDim(dilations_, data_format_, '2') > 0), - absl::InvalidArgumentError("Dilated rates should be larger than 0.")); + errors::InvalidArgument("Dilated rates should be larger than 0.")); } } @@ -695,8 +694,8 @@ class MklConvOp : public OpKernel { const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter); OP_REQUIRES( context, filter_tensor.NumElements() > 0, - absl::InvalidArgumentError("filter must not have zero elements " - "(i.e. all dimensions must be non-zero)")); + errors::InvalidArgument("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); if (std::is_same::value) { (void)SetFPMathMode(); @@ -708,8 +707,8 @@ class MklConvOp : public OpKernel { native_format); OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(), - absl::InvalidArgumentError("Filter should not be in " - "Mkl Layout")); + errors::InvalidArgument("Filter should not be in " + "Mkl Layout")); MklDnnData src(&cpu_engine_); MklDnnData filter(&cpu_engine_); @@ -781,18 +780,18 @@ class MklConvOp : public OpKernel { bool is_conv3d = (strides_.size() == 5); if (!is_conv2d && !is_conv3d) { - OP_REQUIRES(context, !pad_enabled, - absl::InvalidArgumentError( - "Pad + Conv fusion only works for 2D/3D")); + OP_REQUIRES( + context, !pad_enabled, + errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D")); OP_REQUIRES( context, !fuse_pad_, - absl::InvalidArgumentError("Pad+Conv fusion only works for 2D/3D")); + errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D")); } // TODO(intel-tf) 3-D support for Depthwise is not there if (is_depthwise) { OP_REQUIRES(context, is_conv2d, - absl::InvalidArgumentError( + errors::InvalidArgument( "Only 2D convolution is supported for depthwise.")); } @@ -805,7 +804,7 @@ class MklConvOp : public OpKernel { auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, - absl::InvalidArgumentError("Invalid data format")); + errors::InvalidArgument("Invalid data format")); // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. 
@@ -863,9 +862,8 @@ class MklConvOp : public OpKernel { // Inputs to FusedBatchNorm have same 1D shape fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape(); OP_REQUIRES(context, fuse_bn_shape.dims() == 1, - absl::InvalidArgumentError( - absl::StrCat("FusedBatchNorm must be 1D, not: ", - fuse_bn_shape.DebugString()))); + errors::InvalidArgument("FusedBatchNorm must be 1D, not: ", + fuse_bn_shape.DebugString())); // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1}; @@ -993,9 +991,9 @@ class MklConvOp : public OpKernel { string error_msg = tensorflow::strings::StrCat( "Status: ", e.status, ", message: ", string(e.message), ", in file ", __FILE__, ":", __LINE__); - OP_REQUIRES_OK(context, - absl::AbortedError(absl::StrCat( - "Operation received an exception:", error_msg))); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); } } @@ -1008,9 +1006,8 @@ class MklConvOp : public OpKernel { } else { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); OP_REQUIRES(context, paddings_tf.dims() == 2, - absl::InvalidArgumentError( - absl::StrCat("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString()))); + errors::InvalidArgument("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString())); // Flatten tensor to get individual paddings. paddings = static_cast( const_cast(paddings_tf.flat().data())); @@ -1105,9 +1102,9 @@ class MklConvOp : public OpKernel { virtual void ComputeBNScale(OpKernelContext* context, float epsilon, int bn_variance_index, Tinput* scale_buf_ptr) { - OP_REQUIRES(context, false, - absl::UnimplementedError( - "Compute BN scale not expected in base class")); + OP_REQUIRES( + context, false, + errors::Unimplemented("Compute BN scale not expected in base class")); return; } @@ -1218,7 +1215,7 @@ class MklConvOp : public OpKernel { auto output_format_tag = MklTensorFormatToMklDnnDataFormat( output_mkl_shape->GetTfDataFormat()); OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, - absl::InvalidArgumentError( + errors::InvalidArgument( "MklConvOp: AddN fusion: Invalid data format")); auto add_md = add_mkl_shape.IsMklTensor() @@ -1496,14 +1493,14 @@ class MklFusedConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have at least one fused op.")); // TODO(intel-tf): Compact the code for activation checking if (fused_ops == std::vector{"BiasAdd"}) { this->set_fuse_biasadd(true); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"Relu"}) { this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1522,26 +1519,26 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); } else if (fused_ops == std::vector{"BiasAdd", "Relu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == 
std::vector{"BiasAdd", "Relu6"}) { this->set_fuse_biasadd(true); this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "LeakyRelu"}) { this->set_fuse_biasadd(true); @@ -1551,21 +1548,21 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, leakyrelu_alpha); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Add"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"FusedBatchNorm", "Relu"}) { float epsilon; OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); @@ -1574,7 +1571,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->SET_FUSE_ACTIVATION_FOR_RELU6; @@ -1583,7 +1580,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); @@ -1595,7 +1592,7 @@ class MklFusedConvOp context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, @@ -1606,7 +1603,7 @@ class MklFusedConvOp OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); OP_REQUIRES( context, num_args == 4, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D with batchnorm must have 4 extra argument")); this->set_fuse_bn(true, epsilon); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); @@ -1616,7 +1613,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu6"}) { this->set_fuse_biasadd(true); @@ -1624,7 +1621,7 @@ class MklFusedConvOp this->SET_FUSE_ACTIVATION_FOR_RELU6; OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have two 
extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "Elu"}) { this->set_fuse_biasadd(true); @@ -1632,7 +1629,7 @@ class MklFusedConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Add", "LeakyRelu"}) { @@ -1645,25 +1642,24 @@ class MklFusedConvOp leakyrelu_alpha); OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have two extra arguments: bias and add.")); } else if (fused_ops == std::vector{"BiasAdd", "Mish"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_mish, 1.0); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "_FusedConv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "_MklSwish"}) { this->set_fuse_biasadd(true); this->set_fuse_activation(true, dnnl::algorithm::eltwise_swish, 1.0); OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else { OP_REQUIRES(context, false, - absl::UnimplementedError( - absl::StrCat("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]"))); + errors::Unimplemented("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]")); } if (pad_enabled) { @@ -1710,7 +1706,7 @@ class MklFusedDepthwiseConvOp int num_args; OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); OP_REQUIRES(context, !fused_ops.empty(), - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused DepthwiseConv2D must have at least one fused op.")); if (fused_ops == std::vector{"BiasAdd"}) { @@ -1726,14 +1722,13 @@ class MklFusedDepthwiseConvOp this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); } else { OP_REQUIRES(context, false, - absl::UnimplementedError( - absl::StrCat("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]"))); + errors::Unimplemented("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]")); } OP_REQUIRES( context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused DepthwiseConv2D must have one extra argument: bias.")); if (pad_enabled) { @@ -1801,7 +1796,7 @@ class MklQuantizedConvOp // TODO(intel-tf): num_fused_ops and legacy_fused_ops should go away once // old API is abandoned. OP_REQUIRES(context, !(fused_ops_attr.size() > 0 && num_fused_ops > 0), - absl::InvalidArgumentError( + errors::InvalidArgument( "QuantizedConv fused ops should be only available through " "either new API or old API, got both.")); @@ -1818,9 +1813,8 @@ class MklQuantizedConvOp std::find(supported_fusions.begin(), supported_fusions.end(), fused_ops_) != supported_fusions.end(); OP_REQUIRES(context, is_fusion_supported, - absl::InvalidArgumentError( - absl::StrCat("Unsupported QuantizedConv fusion: [", - absl::StrJoin(fused_ops_, ","), "]"))); + errors::InvalidArgument("Unsupported QuantizedConv fusion: [", + absl::StrJoin(fused_ops_, ","), "]")); } // Set the flag for every fused op. 
@@ -1844,10 +1838,9 @@ class MklQuantizedConvOp const bool fuse_requantize = IsFused(oneDNNFusedOps::kRequantize); OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_dt)); if (fuse_requantize) { - OP_REQUIRES( - context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, - absl::InvalidArgumentError("QuantizedConv: unsupported output " - "type when Requantize is fused.")); + OP_REQUIRES(context, out_dt == DT_QINT8 || out_dt == DT_QUINT8, + errors::InvalidArgument("QuantizedConv: unsupported output " + "type when Requantize is fused.")); } if (context->HasAttr("Tsummand")) { @@ -1855,7 +1848,7 @@ class MklQuantizedConvOp if (!this->get_fuse_add()) { OP_REQUIRES( context, summand_dt == out_dt, - absl::InvalidArgumentError( + errors::InvalidArgument( "QuantizedConv: incorrect summand data type. When Sum is not " "fused, Tsummand attribute must have same value as out_type.")); } @@ -1881,7 +1874,7 @@ class MklQuantizedConvOp OP_REQUIRES( context, is_filter_const, - absl::InvalidArgumentError("QuantizedConv: filter must be a constant")); + errors::InvalidArgument("QuantizedConv: filter must be a constant")); if (num_fused_ops == -1) { // If num_fused_ops is -1 then the new API (ops) are being used. @@ -2027,13 +2020,13 @@ class MklQuantizedConvOp context->input(max_input_idx_).template scalar()(); const Tensor& min_filter_vector = context->input(min_filter_idx_); const Tensor& max_filter_vector = context->input(max_filter_idx_); - OP_REQUIRES(context, - ((min_filter_vector.NumElements() > 0) && - (max_filter_vector.NumElements() > 0) && - (min_filter_vector.shape() == max_filter_vector.shape())), - absl::InvalidArgumentError( - "`min_ and max_filter` must have same" - "shape and contain at least one element.")); + OP_REQUIRES( + context, + ((min_filter_vector.NumElements() > 0) && + (max_filter_vector.NumElements() > 0) && + (min_filter_vector.shape() == max_filter_vector.shape())), + errors::InvalidArgument("`min_ and max_filter` must have same" + "shape and contain at least one element.")); // min_freezed_output and max_freezed_output are the actual range // for the output. 
@@ -2093,15 +2086,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_output_tensor.shape()), - absl::InvalidArgumentError( - absl::StrCat("`min_freezed_output` must be rank 0 but is rank ", - min_freezed_output_tensor.dims()))); + errors::InvalidArgument( + "`min_freezed_output` must be rank 0 but is rank ", + min_freezed_output_tensor.dims())); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_output_tensor.shape()), - absl::InvalidArgumentError( - absl::StrCat("`max_freezed_output` must be rank 0 but is rank ", - max_freezed_output_tensor.dims()))); + errors::InvalidArgument( + "`max_freezed_output` must be rank 0 but is rank ", + max_freezed_output_tensor.dims())); const Tensor& min_freezed_summand_tensor = context->input(min_summand_idx_); const Tensor& max_freezed_summand_tensor = @@ -2109,15 +2102,15 @@ class MklQuantizedConvOp OP_REQUIRES( context, TensorShapeUtils::IsScalar(min_freezed_summand_tensor.shape()), - absl::InvalidArgumentError(absl::StrCat( + errors::InvalidArgument( "`min_freezed_summand` must be rank 0 but is rank ", - min_freezed_summand_tensor.dims()))); + min_freezed_summand_tensor.dims())); OP_REQUIRES( context, TensorShapeUtils::IsScalar(max_freezed_summand_tensor.shape()), - absl::InvalidArgumentError(absl::StrCat( + errors::InvalidArgument( "`max_freezed_summand` must be rank 0 but is rank ", - max_freezed_summand_tensor.dims()))); + max_freezed_summand_tensor.dims())); const float min_freezed_output = min_freezed_output_tensor.template scalar()(); const float max_freezed_output = @@ -2192,7 +2185,7 @@ class MklQuantizedConvOp OP_REQUIRES(context, context->forward_input_to_output_with_shape( summand_idx, 0, summand.shape(), output_tensor), - absl::InvalidArgumentError( + errors::InvalidArgument( "Summand cannot be forwarded in the current fusion.")); return; } @@ -2473,14 +2466,13 @@ class MklFusedConv3DOp std::vector padding_list; OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); if (padding_list.empty()) { - OP_REQUIRES( - context, !fused_ops.empty(), - absl::InvalidArgumentError("Fused Conv3D must have at least one " - "fused op when Pad is not fused.")); + OP_REQUIRES(context, !fused_ops.empty(), + errors::InvalidArgument("Fused Conv3D must have at least one " + "fused op when Pad is not fused.")); if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end()) { OP_REQUIRES(context, num_args == 1, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv3D must have one extra argument: bias.")); } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == fused_ops.end() && @@ -2488,7 +2480,7 @@ class MklFusedConv3DOp fused_ops.end()) { OP_REQUIRES( context, num_args == 2, - absl::InvalidArgumentError( + errors::InvalidArgument( "Fused Conv3D must have two extra arguments: bias and add.")); } } @@ -2541,9 +2533,8 @@ class MklFusedConv3DOp } else { if (padding_list.empty()) { OP_REQUIRES(context, false, - absl::UnimplementedError( - absl::StrCat("Fusion is not implemented: [", - absl::StrJoin(fused_ops, ","), "]"))); + errors::Unimplemented("Fusion is not implemented: [", + absl::StrJoin(fused_ops, ","), "]")); } } } From 9dd9c375de6a4119a09f51c0682e8175d7cbe2c3 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Mon, 17 Jul 2023 09:50:41 -0700 Subject: [PATCH 026/410] address review comments --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 8dae3705e0a811..7b7bc416689c9f 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -2643,7 +2643,8 @@ REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ summand_type, is_depthwise, legacy_fused_ops, \ num_fused_ops) \ - + #define BIAS_TYPE_CONSTRAINT(bias_type) #define SUMMAND_TYPE_CONSTRAINT(summand_type) #define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel) From b41fbd31d0dfc06fe6638349c6cf8d748f528ef6 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 17 Jul 2023 11:18:40 -0700 Subject: [PATCH 027/410] Improve based on review. --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 142 ++++++++++-------- 1 file changed, 83 insertions(+), 59 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 5aa1d7804aaa11..71c79825bec958 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -97,6 +97,44 @@ bool IsF8Type(const HloInstruction *instr) { return primitive_util::IsF8Type(instr->shape().element_type()); } +// Returns a new shape with non-batch dimensions padded to multiples of 16, as +// required by cuBLASLt FP8 gemms. +Shape PadShapeToMultipleOf16(const Shape old_shape, + const absl::Span batch_dims) { + Shape padded_shape = old_shape; + for (int i = 0; i < old_shape.rank(); ++i) { + if (!absl::c_linear_search(batch_dims, i)) { + int64_t padded_dimension = + RoundUpTo(old_shape.dimensions(i), 16); + padded_shape.set_dimensions(i, padded_dimension); + } + } + return padded_shape; +} + +// Pad the non-batch dimensions of the operands to multiples of 16 as required +// by cuBLASLt. +HloInstruction *PadOperandToMultipleOf16(absl::Span batch_dims, + HloInstruction *instr, + HloInstruction *x) { + PaddingConfig padding_config; + Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims); + for (int i = 0; i < x->shape().rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_low(0); + dimension->set_edge_padding_high(padded_shape.dimensions(i) - + x->shape().dimensions(i)); + dimension->set_interior_padding(0); + } + if (!ShapeUtil::Equal(padded_shape, x->shape())) { + HloInstruction *zero = instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + return instr->AddInstruction( + HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); + } + return x; +} + // Recursively collects unary, pad, divide or multiply operands of instr until // an instruction with FP8 element type is reached. Returns std::nullopt when no // FP8 instruction is reached. 
@@ -584,22 +622,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr = new_add; } - if (Match(instr, - m::AddAnyOrder( - GemmOrCublasLtMatmulMaybeF8(&existing_gemm).WithOneUser(), - m::Op(&bias).WithPredicate(is_not_broadcast)))) { - return FuseMatrixBiasAdd(instr, bias, existing_gemm); - } - HloInstruction *optional_bitcast_matrix = nullptr; HloInstruction *optional_slice_matrix = nullptr; if (Match(instr, m::AddAnyOrder( OptionalBitcast( &optional_bitcast_matrix, - OptionalSlice( - &optional_slice_matrix, - CublasLtMatmulMaybeF8(&existing_gemm).WithOneUser())) + OptionalSlice(&optional_slice_matrix, + GemmOrCublasLtMatmulMaybeF8(&existing_gemm) + .WithOneUser())) .WithOneUser(), m::Op(&bias).WithPredicate(is_not_broadcast)))) { return FuseMatrixBiasAdd(instr, bias, existing_gemm, @@ -661,43 +692,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - // Get the padded shape. - Shape pad_shape(const Shape old_shape, - const absl::Span batch_dims) { - Shape padded_shape = old_shape; - for (int i = 0; i < old_shape.rank(); ++i) { - if (!absl::c_linear_search(batch_dims, i)) { - int64_t padded_dimension = - RoundUpTo(old_shape.dimensions(i), 16); - padded_shape.set_dimensions(i, padded_dimension); - } - } - return padded_shape; - } - - // Pad the non-batch dimensions of the operands to multiples of 16 as - // required by cuBLASLt. - void pad_operand(absl::Span batch_dims, HloInstruction *&instr, - HloInstruction *&x) { - PaddingConfig padding_config; - Shape padded_shape = pad_shape(x->shape(), batch_dims); - for (int i = 0; i < x->shape().rank(); ++i) { - auto dimension = padding_config.add_dimensions(); - dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - - x->shape().dimensions(i)); - dimension->set_interior_padding(0); - } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = - instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - x = instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return; - } - StatusOr CreateF8CustomCall(HloInstruction *instr, GemmBackendConfig &gemm_backend_config, HloInstruction *a, HloInstruction *b, @@ -879,9 +873,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); } - pad_operand(batch_dims, instr, a); - pad_operand(batch_dims, instr, b); - Shape new_output_shape = pad_shape(instr->shape(), batch_dims); + a = PadOperandToMultipleOf16(batch_dims, instr, a); + b = PadOperandToMultipleOf16(batch_dims, instr, b); + Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims); std::vector operands_list = { a, b, scales_f32[0], scales_f32[1], one, one}; @@ -1052,13 +1046,17 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } + // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add + // instruction in the following form: + // Add(OptionalBitcast(OptionalSlice(gemm)), bias) where 'gemm' is expected + // to be a cuBLAS custom_call. Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, const HloInstruction *gemm, HloInstruction *bitcast = nullptr, HloInstruction *slice = nullptr) { - TF_RET_CHECK( - (bias->shape() == (bitcast ? bitcast->shape() : gemm->shape())) || - (bias->shape() == (slice ? slice->shape() : gemm->shape()))); + TF_RET_CHECK(bias->shape() == (bitcast ? bitcast->shape() + : slice ? 
slice->shape() + : gemm->shape())); // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only // supports fixed values for alpha/beta. @@ -1122,8 +1120,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - pad_operand(config.dot_dimension_numbers().rhs_batch_dimensions(), instr, - broadcast_bias); + broadcast_bias = PadOperandToMultipleOf16( + config.dot_dimension_numbers().rhs_batch_dimensions(), instr, + broadcast_bias); } operands.insert(operands.begin() + 2, broadcast_bias); @@ -1175,6 +1174,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *slice = nullptr, HloInstruction *convert = nullptr, HloInstruction *bitcast = nullptr) { + // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add + // instruction in the following form: + // Add(OptionalBitcast(OptionalSlice(gemm)), broadcast), where 'gemm' is + // expected to be a cuBLAS custom_call. The optional convert is only + // applicable to F8 matmul as cublasLt has specific constraints on the + // vector bias type. The optional bitcast is necessary to handle high rank + // input cases. if (bitcast == nullptr) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); @@ -1203,8 +1209,24 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // dimensions; i.e. its most minor physical dimensions align with most minor // physical dimensions of the gemm output. absl::Span broadcast_dims = broadcast->dimensions(); - if (bitcast) { + if (bitcast != nullptr) { broadcast_dims = gemm->shape().dimensions(); + } else { + for (size_t i = 0; i < num_col_dims; ++i) { + int64_t dim = gemm->shape().layout().minor_to_major(i); + + // Find the corresponding dimension from the bias vector. + auto it = absl::c_find(broadcast_dims, dim); + + if (it == broadcast_dims.end()) { + return false; + } + + int64_t vector_dim = it - broadcast_dims.begin(); + if (bias->shape().layout().minor_to_major(i) != vector_dim) { + return false; + } + } } std::vector operands(gemm->operands().begin(), @@ -1251,9 +1273,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - if (bitcast) { - pad_operand(config.dot_dimension_numbers().rhs_batch_dimensions(), instr, - bias); + // In the case of high rank input, it is necessary to consider potential + // padding for the bias. + if (bitcast != nullptr && slice != nullptr) { + bias = PadOperandToMultipleOf16( + config.dot_dimension_numbers().rhs_batch_dimensions(), instr, bias); } // Replace add(gemm, broadcast) with fused new_gemm. 
operands.push_back(bias); From 6e669dff1a9dd963f9c0ad597be8c4d6c352fd2a Mon Sep 17 00:00:00 2001 From: Sachin Muradi Date: Wed, 19 Jul 2023 17:35:49 -0700 Subject: [PATCH 028/410] Fix swish and mish --- .../core/grappler/optimizers/remapper.cc | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 6706f1c97ff8be..9a8ffe9125d947 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -1860,6 +1860,18 @@ bool FindSigmoidAndMul(RemapperContext* ctx, int node_index, sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index), matched_nodes_map, remove_node_indices); + if (found_op_type_match) { + NodeDef* matched_sigmoid_node = + ctx->graph_view.GetNode(matched_nodes_map->at("sigmoid"))->node(); + auto in_tensor_sigmoid = matched_sigmoid_node->input(0); + if ((mul_node_def->input(0) != in_tensor_sigmoid) && + (mul_node_def->input(1) != in_tensor_sigmoid)) { + // If the input tensor of Sigmoid doesn't match with either of input + // tensors of mul return false + found_op_type_match = false; + } + } + return found_op_type_match; } @@ -4185,6 +4197,18 @@ bool FindSoftplusAndTanhAndMul(RemapperContext* ctx, int node_index, softplustanhmul_pattern, {}, ctx->graph_view.GetNode(node_index), matched_nodes_map, remove_node_indices); + if (found_op_type_match) { + NodeDef* matched_softplus_node = + ctx->graph_view.GetNode(matched_nodes_map->at("softplus"))->node(); + auto in_tensor_softplus = matched_softplus_node->input(0); + if ((mul_node_def->input(0) != in_tensor_softplus) && + (mul_node_def->input(1) != in_tensor_softplus)) { + // If the input tensor of Softplus doesn't match with either of input + // tensors of mul return false + found_op_type_match = false; + } + } + return found_op_type_match; } From e939fdf0ffe26b3766e064ff5efb45fc0b8ce121 Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Thu, 20 Jul 2023 13:49:29 +0800 Subject: [PATCH 029/410] gpu_kernel_helper.h: Restrict `int tf_min/max()` overloading --- tensorflow/core/util/gpu_kernel_helper.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h index 04d4d97810a186..ba6ae9ab153fa0 100644 --- a/tensorflow/core/util/gpu_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -170,6 +170,19 @@ __host__ __device__ inline double tf_max(double x, double y) { return fmax(x, y); } +#ifdef _MSC_VER +#if _MSC_VER >= 1930 +using std::max; +using std::min; +__host__ __device__ inline int tf_min(int x, int y) { + return min(x, y); +} +__host__ __device__ inline int tf_max(int x, int y) { + return max(x, y); +} +#endif +#endif + // ROCM TODO re-enable them after adding fp16 support logic #if GOOGLE_CUDA __device__ inline Eigen::half GpuShuffleSync(unsigned mask, Eigen::half value, From cd6cfe7bcc78c2eb6883025a9be800058bf4bb5c Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Thu, 20 Jul 2023 13:57:44 +0800 Subject: [PATCH 030/410] Update segment_reduction_ops_gpu.cu.h --- .../core/kernels/segment_reduction_ops_gpu.cu.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h index 3e7c37f4e14ad3..4c4ff91ef4b792 100644 --- 
a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h @@ -347,7 +347,7 @@ __global__ void SegmentReduceVectorKernel( Treducevec block_result = x_ok && y_ok ? input_vec[input_idx] : Tvec(initial_value); // Apply weights if provided. - if (weights && y_ok) block_result *= Tvec(weights[y_idx]); + if (weights && y_ok) block_result = block_result * Tvec(weights[y_idx]); // MSVC fix // Reduce along the columns of the block, returning result in first row. block_result = ReduceBlockAlongCols(reduce_op, block_result, x_ok); if (y == 0 && x_ok) { @@ -363,9 +363,9 @@ __global__ void SegmentReduceVectorKernel( typename RealTypeIfComplex::type total_weight(end - begin); // Normalize the results if necessary. if (is_mean) { - result /= Treducevec(total_weight); + result = result / Treducevec(total_weight); // MSVC fix } else if (is_sqrtn) { - result /= Treducevec(sqrt(total_weight)); + result = result / Treducevec(sqrt(static_cast(total_weight))); // MSVC fix } } // Cast from Treducevec to Tvec. @@ -439,10 +439,11 @@ __global__ void SegmentReduceEpilogueKernel( // Empty segment. val = Treducevec(empty_segment_value); } else if (is_mean) { - val /= Treducevec(segment_size); + val = val / Treducevec(segment_size); // MSVC fix } else if (is_sqrtn) { - val /= Treducevec( - sqrt(typename RealTypeIfComplex::type(segment_size))); + // MSVC fix + val = val / Treducevec( + sqrt(static_cast(typename RealTypeIfComplex::type(segment_size)))); } // Cast from Treducevec to Tvec. output[seg] = static_cast(val); @@ -491,7 +492,7 @@ struct LookupAndScaleAndCastInputsFunctor { __device__ Treducevec operator()(Tindex idx) const { if (indices_) idx = indices_[idx]; Treducevec result = static_cast(input_vec_[idx]); - if (weights_) result *= Tvec(weights_[idx]); + if (weights_) result = result * Tvec(weights_[idx]); // MSVC fix return result; } From 9b7832bc4d2f3d90308b6ccb1788519d38e2ed94 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Thu, 16 Mar 2023 23:51:35 +0000 Subject: [PATCH 031/410] Reduce MKL overheads on small shapes by not rewriting node to use MKL --- .../core/common_runtime/mkl_layout_pass.cc | 321 +++++++++++++++++- .../core/common_runtime/mkl_layout_pass.h | 18 + .../optimize_function_graph_utils.cc | 7 + .../core/grappler/optimizers/remapper.cc | 93 +++-- tensorflow/tsl/platform/cpu_info.cc | 91 +++++ 5 files changed, 459 insertions(+), 71 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 5cfd191072bfc8..85099ac94eec96 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -49,6 +50,18 @@ limitations under the License. 
namespace tensorflow { +/// This table contains for each node name descriptors on which +/// hardware to check whether we should rewrite the operations +/// to use MKL based on the parameters for heuristic +static const RewriteThreshold rewrite_thresholds[] = { +#ifdef DNNL_AARCH64_USE_ACL + {"Conv2D", 0x41, 0xd40, {0.9349, 22.603}}, + {"_FusedConv2D", 0x41, 0xd40, {0.9349, 22.603}}, + {"FusedBatchNormV3", 0x41, 0xd40, {0.3223, -0.8822}}, + {"Sigmoid", 0x41, 0xd40, {0.0, 0.064736}}, +#endif // DNNL_AARCH64_USE_ACL + {"", 0x0, 0x0, {0, 0}}}; + // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph // (B) Rewriting a node in the graph to a new node @@ -239,7 +252,7 @@ namespace tensorflow { // class MklLayoutRewritePass : public GraphOptimizationPass { public: - MklLayoutRewritePass() { + MklLayoutRewritePass() : num_intra_threads_(0) { // NOTE: names are alphabetically sorted. csinfo_.addn = "AddN"; csinfo_.avg_pool = "AvgPool"; @@ -380,6 +393,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mul = "Mul"; csinfo_.squared_difference = "SquaredDifference"; csinfo_.sub = "Sub"; + csinfo_.sigmoid = "Sigmoid"; + csinfo_.swish = "_MklSwish"; // End - element-wise ops. See note above. const bool native_fmt = NativeFormatEnabled(); @@ -425,9 +440,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); - rinfo_.push_back( - {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), - CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()}); + rinfothr_.push_back( + {{csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), + CopyAttrsConvCheckConstFilter, + std::function(), // we set this function to empty + GetRewriteCause()}, + Conv2DRewrite}); rinfo_.push_back({csinfo_.conv2d_with_bias, native_fmt ? csinfo_.mkl_native_conv2d_with_bias : csinfo_.mkl_conv2d_with_bias, @@ -486,10 +504,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Using CopyAttrsAll for V3 on CPU, as there are no additional // attributes. - rinfo_.push_back( - {csinfo_.fused_batch_norm_v3, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3), - CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); + rinfothr_.push_back( + {{csinfo_.fused_batch_norm_v3, + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3), + CopyAttrsAll, std::function(), GetRewriteCause()}, + FusedBatchNormV3RewriteWithThreads}); rinfo_.push_back( {csinfo_.fused_batch_norm_grad_v3, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), @@ -499,11 +518,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { : csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll, FusedBatchNormExRewrite, GetRewriteCause()}); - rinfo_.push_back({csinfo_.fused_conv2d, - native_fmt ? csinfo_.mkl_native_fused_conv2d - : csinfo_.mkl_fused_conv2d, - CopyAttrsAllCheckConstFilter, FusedConv2DRewrite, - GetRewriteCause()}); + rinfothr_.push_back( + {{csinfo_.fused_conv2d, + native_fmt ? 
csinfo_.mkl_native_fused_conv2d + : csinfo_.mkl_fused_conv2d, + CopyAttrsAllCheckConstFilter, + std::function(), // we set this function to empty + GetRewriteCause()}, + FusedConv2DRewrite}); rinfo_.push_back({csinfo_.fused_conv3d, csinfo_.mkl_native_fused_conv3d, CopyAttrsAllCheckConstFilter, AlwaysRewrite, kRewriteForOpNameChange}); @@ -743,6 +765,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { wsinfo_.push_back( {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3}); + // Rule for merging sigmoid and multiplication to swish + minfo_.push_back( + {csinfo_.sigmoid, csinfo_.mul, csinfo_.swish, GetSigmoidAndMul}); // Add a rule for merging nodes minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); @@ -808,6 +833,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass { RewriteCause rewrite_cause; } RewriteInfo; + /// Structure that carries the original rewrite info, but + /// in this case it uses a rewrite function that also accepts + /// the number of threads that will be used to run + /// the operation in parallel + typedef struct { + RewriteInfo rinfo; + std::function rewrite_rule; + } RewriteInfoThreadCount; + /// Structure to specify a forward op, a backward op, and the slot numbers /// in the forward and backward ops where we will add a workspace edge. typedef struct { @@ -977,6 +1011,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string relu6; string relu6_grad; string requantize; + string sigmoid; + string swish; string tanh; string tanh_grad; string transpose; @@ -991,6 +1027,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { private: /// Maintain info about nodes to rewrite std::vector rinfo_; + /// Maintain info about nodes to rewrite with additional + /// information that holds the number of threads that should + /// be used to run the kernel on so that we can decide + /// whether it is worth rewriting the op to run with MKL + std::vector rinfothr_; /// Maintain info about nodes to add workspace edge std::vector wsinfo_; @@ -1004,6 +1045,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Maintain structure of constant strings static ConstStringsInfo csinfo_; + /// Number of threads used for intra-parallelism + int num_intra_threads_; + private: // Is OpDef::ArgDef a list type? It could be N * T or list(type). // Refer to opdef.proto for details of list type. @@ -1089,11 +1133,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Status MergeNode(std::unique_ptr* g, Node* m, Node* n); // Helper function to merge different nodes + Status MergeSigmoidWithMul(std::unique_ptr* g, Node* m, Node* n); Status MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n); Status MergePadWithConv2D(std::unique_ptr* g, Node* m, Node* n); Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr* g, Node* m, Node* n); + static Node* GetSigmoidAndMul(const Node* m) { + DCHECK(m); + Node* n = nullptr; + + if (m->type_string() == csinfo_.sigmoid) { + for (const Edge* e : m->out_edges()) { + if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.mul && + e->dst_input() == 0) { + n = e->dst(); + break; + } + } + } + return n; + } // Find BiasAdd or Conv2D node that can be merged with input node 'm'. // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be // merged with 'm'.
If input 'm' is Conv2D, then check if there exists BiasAdd @@ -1134,6 +1194,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return n; } + static double FindRewriteThreshold(const Node* n, int threads) { + int cpu_family_ = port::CPUFamily(); + int cpu_model_num_ = port::CPUModelNum(); + + if (threads == 0) { + // if we do not have information on how many threads are used + // to parallelise the operation we revert to the old behaviour + return 0; + } + + for (const RewriteThreshold* i = rewrite_thresholds; + i->op != "" && threads > 0; i++) { + if (n->type_string() == i->op && cpu_family_ == i->cpu_family && + cpu_model_num_ == i->cpu_model_num) { + return i->params.thread_sync_cost * threads + i->params.framework_cost; + } + } + + return 0; + } + // Find Pad or Conv2D node that can be merged with input node 'm'. // If input 'm' is Pad, then check if there exists Conv2D node that can be // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad @@ -1751,6 +1832,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return true; } + static bool FusedBatchNormV3RewriteWithThreads(const Node* n, int threads) { + double mflops = CalculateNodeMFlops(n); + double thr = FindRewriteThreshold(n, threads); + if (mflops > 0 && mflops < thr) { + return false; + } + + return FusedBatchNormV3Rewrite(n); + } + static bool FusedBatchNormExRewrite(const Node* n) { DCHECK(n); @@ -1776,7 +1867,66 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return true; } - static bool FusedConv2DRewrite(const Node* n) { + static double CalculateNodeMFlops(const Node* n) { + // Check if we can obtain the dimensions of this node + std::vector shape_attrs; + if (!TryGetNodeAttr(n->attrs(), "_input_shapes", &shape_attrs)) { + // We can't obtain the shape so we will revert to the default + // behaviour and rewrite the node + return -1; + } + + if ((n->type_string() == "Conv2D" || n->type_string() == "_FusedConv2D") && + shape_attrs.size() == 2) { + TensorShape input_shape, filter_shape; + if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != + tsl::OkStatus()) { + return -1; + } + if (TensorShape::BuildTensorShape(*shape_attrs[1], &filter_shape) != + tsl::OkStatus()) { + return -1; + } + + // MFLOPS = N * H * W * C * FH * FW * FC / 1e6. + return input_shape.dim_size(0) * input_shape.dim_size(1) * + input_shape.dim_size(2) * input_shape.dim_size(3) * + filter_shape.dim_size(0) * filter_shape.dim_size(1) * + filter_shape.dim_size(3) / (double)1e6; + } else if ((n->type_string() == "FusedBatchNormV3" || + n->type_string() == "Sigmoid") && + shape_attrs.size() >= 1) { + TensorShape input_shape; + if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != + tsl::OkStatus()) { + return -1; + } + return input_shape.dim_size(0) * input_shape.dim_size(1) * + input_shape.dim_size(2) * input_shape.dim_size(3) / (double)1e6; + } + + return -1; + } + + static bool Conv2DRewrite(const Node* n, int threads) { + // Find out the dimensions of the convolution. + // If the dimensions are small we will not rewrite the node + // to use MKL operations, as the overhead of calling into MKL + // and setting up data is higher than the actual useful work we + // might end up doing + double total_mflops = CalculateNodeMFlops(n); + double thr = FindRewriteThreshold(n, threads); + + return true ?
(total_mflops < 0 || total_mflops >= thr) : false; + } + + static bool FusedConv2DRewrite(const Node* n, int threads) { + // Decide whether it is worth rewriting to an MKL operation, + // due to overheads, as they will dominate for small shapes + if (!Conv2DRewrite(n, threads)) { + return false; + } + // MKL DNN currently doesn't support all fusions that grappler fuses // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if // it includes those we support. @@ -2998,6 +3148,129 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { return nullptr; } +Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, + Node* m, Node* n) { + CHECK_EQ( + (m->type_string() == csinfo_.sigmoid && n->type_string() == csinfo_.mul), + true); + + // Decide whether it is worth merging Sigmoid+Mul into + // a call to _MklSwish + double total_mflops = CalculateNodeMFlops(m); + double thr = FindRewriteThreshold(m, num_intra_threads_); + + if (total_mflops != -1 && total_mflops < thr) { + // Do not merge; execute them as they are, + // because the overhead of going to MKL will dominate + // any benefit from accelerating the compute + return Status(error::Code::CANCELLED, + "Sigmoid and Mul operate on small shapes, " + "so there is no benefit in optimizing to Swish. " + "Will skip node merge optimization"); + } + + Node* sigmoid = m; + Node* mul = n; + + DataType T_sigmoid, T_mul; + TF_CHECK_OK(GetNodeAttr(sigmoid->def(), "T", &T_sigmoid)); + TF_CHECK_OK(GetNodeAttr(mul->def(), "T", &T_mul)); + + const int mul_num = mul->num_inputs(); + gtl::InlinedVector mul_control_edges; + gtl::InlinedVector, 4> mul_in(mul_num); + FillInputs(mul, &mul_control_edges, &mul_in); + + const int sigmoid_num = sigmoid->num_inputs(); + gtl::InlinedVector sigmoid_control_edges; + gtl::InlinedVector, 4> sigmoid_in(sigmoid_num); + FillInputs(sigmoid, &sigmoid_control_edges, &sigmoid_in); + + CHECK_EQ(sigmoid->in_edges().size(), 1); // Sigmoid has 1 input + CHECK_EQ(mul->in_edges().size(), 2); // Mul has 2 inputs + for (const Edge* e : mul->in_edges()) { + const int kFirstInputSlot = 0; // this should be sigmoid + const int kSecondInputSlot = + 1; // this should be the same input as in sigmoid + if (e->dst_input() == kFirstInputSlot && e->src() != sigmoid) { + return Status(error::Code::INVALID_ARGUMENT, + "Sigmoid doesn't feed to Mul. " + "Will skip node merge optimization"); + } + const Edge* kSigmoidInputEdge = nullptr; + TF_CHECK_OK(sigmoid->input_edge(kFirstInputSlot, &kSigmoidInputEdge)); + if (e->dst_input() == kSecondInputSlot && + e->src() != kSigmoidInputEdge->src()) { + return Status(error::Code::INVALID_ARGUMENT, + "Input to Sigmoid and Mul is not the same. 
" + "Will skip node merge optimization"); + } + } + + NodeBuilder nb(mul->name(), csinfo_.swish); + nb.Input(sigmoid_in[0].first, sigmoid_in[0].second); + nb.Attr("T", T_mul); // copy type attribute + nb.Device(mul->def().device()); + + // Create new node + Node* new_node; + TF_CHECK_OK(nb.Finalize(&**g, &new_node)); + + std::unordered_set unique_node; + for (const Edge* e : sigmoid->in_edges()) { + if (e->IsControlEdge()) { + auto result = unique_node.insert(e->src()); + if (result.second) { + (*g)->AddControlEdge(e->src(), new_node, true); + } + } + } + unique_node.clear(); + + for (const Edge* e : mul->in_edges()) { + if (e->IsControlEdge()) { + auto result = unique_node.insert(e->src()); + if (result.second) { + (*g)->AddControlEdge(e->src(), new_node, true); + } + } + } + unique_node.clear(); + + for (const Edge* e : sigmoid->out_edges()) { + if (e->IsControlEdge()) { + auto result = unique_node.insert(e->dst()); + if (result.second) { + (*g)->AddControlEdge(new_node, e->dst(), true); + } + } + } + unique_node.clear(); + for (const Edge* e : mul->out_edges()) { + if (e->IsControlEdge()) { + auto results = unique_node.insert(e->dst()); + if (results.second) { + (*g)->AddControlEdge(new_node, e->dst(), true); + } + } else { + const int kMulOutputSlot = 0; + auto new_edge = + (*g)->AddEdge(new_node, kMulOutputSlot, e->dst(), e->dst_input()); + DCHECK(new_edge); + } + } + + new_node->set_assigned_device_name(mul->assigned_device_name()); + VLOG(1) << "MklLayoutRewritePass: Merged old node: " << sigmoid->DebugString() + << ", and node: " << mul->DebugString() + << ", into node: " << new_node->DebugString(); + + (*g)->RemoveNode(sigmoid); + (*g)->RemoveNode(mul); + + return OkStatus(); +} + Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add && @@ -3481,6 +3754,9 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* m, DCHECK(m); DCHECK(n); + if (m->type_string() == csinfo_.sigmoid && n->type_string() == csinfo_.mul) { + return this->MergeSigmoidWithMul(g, m, n); + } if (((m->type_string() == csinfo_.bias_add && n->type_string() == csinfo_.conv2d)) || ((n->type_string() == csinfo_.bias_add && @@ -3489,10 +3765,12 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* m, } if ((m->type_string() == csinfo_.pad && (n->type_string() == csinfo_.conv2d || - (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) || + (n->type_string() == csinfo_.fused_conv2d && + FusedConv2DRewrite(n, num_intra_threads_)))) || (n->type_string() == csinfo_.pad && (m->type_string() == csinfo_.conv2d || - (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) { + (m->type_string() == csinfo_.fused_conv2d && + FusedConv2DRewrite(m, num_intra_threads_))))) { return this->MergePadWithConv2D(g, m, n); } @@ -3811,6 +4089,13 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { } } + for (auto rit = rinfothr_.cbegin(); rit != rinfothr_.cend(); ++rit) { + if (n->type_string().compare(rit->rinfo.name) == 0 && + rit->rewrite_rule(n, num_intra_threads_)) { + return &(rit->rinfo); + } + } + // Else return not found. 
return nullptr; } @@ -4167,6 +4452,10 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { return OkStatus(); } + if (options.session_options != nullptr) { + num_intra_threads_ = + options.session_options->config.intra_op_parallelism_threads(); + } auto process_graph = [&](std::unique_ptr* g) { // Get the ownership of a graph std::unique_ptr* ng = std::move(g); diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.h b/tensorflow/core/common_runtime/mkl_layout_pass.h index 6b5c586ceabb3e..8c14efb8e6dc68 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.h +++ b/tensorflow/core/common_runtime/mkl_layout_pass.h @@ -25,6 +25,24 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" namespace tensorflow { + +struct RewriteThreshold { + string op; + int cpu_family; + int cpu_model_num; + // The model that is used to decide whether it is worth + // accelerating operations using oneDNN is: + // threshold = thread_synchronisation*thread_num + framework_tax + // which finds threshold when framework overhead and thread synchronisations + // are amortized with amount of computation that has to be performed. + // If we are below this threshold then we will not rewrite the operation to + // to be run using oneDNN primitive. + struct PerformanceParameters { + double thread_sync_cost; + double framework_cost; + } params; +}; + // Interface to invoke the pass for unit test // // Returns true if and only if 'g' is mutated. diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index 921eb3a4c064ba..d1394850c4a1f8 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -553,6 +553,13 @@ StatusOr OptimizeFunctionGraph( options.shape_inference_on_tfe_dialect_import; optimization_options.debug_filename_prefix = function_name; + if (cpu_device->tensorflow_cpu_worker_threads() != nullptr) { + // Pass to the optimisation pass number of intra threads that are used to + // parallelise operations + session_options.config.set_intra_op_parallelism_threads( + cpu_device->tensorflow_cpu_worker_threads()->num_threads); + } + DEBUG_DATA_DUMPER()->DumpGraph(function_name, kDebugGroupMain, "before_pre_placement_passes", graph.get(), &reachable_lib_def, false); diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 6706f1c97ff8be..f7807d273dab74 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -733,6 +733,23 @@ bool IsBiasSemanticAdd(const RemapperContext& ctx, return false; } +void AddInputShapesAttr(const RemapperContext& ctx, int node_index) { + auto mutable_node = ctx.graph_view.graph()->mutable_node(node_index); + + AttrValue attr_input_shape; + auto tensor_properties = + ctx.graph_properties.GetInputProperties(mutable_node->name()); + for (const auto& tensor_property : tensor_properties) { + TensorShapeProto* proto = attr_input_shape.mutable_list()->add_shape(); + *proto = tensor_property.shape(); + } + + if (IsMKLEnabled()) { + (*mutable_node->mutable_attr())["_input_shapes"] = + std::move(attr_input_shape); + } +} + bool FindContractionWithBias(const RemapperContext& ctx, int node_index, ContractionWithBiasAdd* matched, bool check_device_compatible = true) { @@ -1819,50 +1836,6 @@ bool FindMulAndMaximum(RemapperContext* ctx, int node_index, return 
found_op_type_match; } -bool FindSigmoidAndMul(RemapperContext* ctx, int node_index, - std::map* matched_nodes_map, - std::set* remove_node_indices) { - // Gelu fusion is enabled only with oneDNN library. - if (!IsMKLEnabled()) return false; - - using utils::MatchingDirection; - using utils::NodeStatus; - // clang-format off - // Convert Sigmoid+Mul to Swish - // Mul(x, Sigmoid(x)) --> _MklSwish(x) - - utils::OpTypePattern sigmoidmul_pattern{ - "Mul", "mul_to_swish", NodeStatus::kReplace, - { - { "Sigmoid", "sigmoid", NodeStatus::kRemove, - { - { "*", "input", NodeStatus::kRemain} - } - }, - { "*", "input", NodeStatus::kRemain} - } - }; - // clang-format on - // check for data types - auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node(); - if (!HasDataType(mul_node_def, DT_FLOAT) && - !HasDataType(mul_node_def, DT_BFLOAT16)) - return false; - - if (!NodeIsOnCpu(mul_node_def)) return false; - - bool found_op_type_match = false; - utils::SubGraphMatcher graph_matcher( - &(ctx->graph_view)); - matched_nodes_map->clear(); - remove_node_indices->clear(); - found_op_type_match = graph_matcher.GetMatchedNodes( - sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index), - matched_nodes_map, remove_node_indices); - - return found_op_type_match; -} - // Keras LayerNormalization api uses multiple TensorFlow ops. Current fusion // pattern is only for the case, when LayerNormalization uses FusedBatcNormV3. // We further restrict it to only 2D or 3D tensor inputs to keras @@ -2762,6 +2735,11 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d, (*attr)["dilations"] = src_attr.at("dilations"); (*attr)["data_format"] = src_attr.at("data_format"); (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); + // When copying attributes check whether this convolution has + // attribute that describes the shapes on which it is working + if (IsMKLEnabled()) { + (*attr)["_input_shapes"] = src_attr.at("_input_shapes"); + } // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha if (activation != nullptr && IsLeakyRelu(*activation)) { auto& activation_attr = activation->attr(); @@ -2914,6 +2892,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias if (IsConv2D(contraction)) { fused_op.set_op(kFusedConv2D); + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &fused_op); } else if (IsDepthwiseConv2dNative(contraction)) { fused_op.set_op(kFusedDepthwiseConv2dNative); @@ -3017,6 +2996,7 @@ Status AddFusedContractionNode( if (IsConv2D(contraction)) { fused_op.set_op(kFusedConv2D); + AddInputShapesAttr(*ctx, matched.contraction); // leaky relu has a special attribute alpha CopyConv2DAttributes(contraction, &fused_op, &activation); } else if (IsDepthwiseConv2dNative(contraction)) { @@ -3071,6 +3051,7 @@ Status AddFusedConvNode(RemapperContext* ctx, if (IsConv2D(contraction)) { fused_conv.set_op(kFusedConv2D); + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &fused_conv); } else if (IsConv3D(contraction)) { fused_conv.set_op(kFusedConv3D); @@ -3121,6 +3102,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &fused_conv2d); SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm"}, /*num_args=*/4, /*epsilon=*/matched.epsilon); @@ 
-3164,6 +3146,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, fused_conv2d.add_input(fused_batch_norm.input(3)); // 4: mean fused_conv2d.add_input(fused_batch_norm.input(4)); // 5: variance + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &fused_conv2d, &activation); SetFusedOpAttributes(&fused_conv2d, {"FusedBatchNorm", activation.op()}, /*num_args=*/4, /*epsilon=*/matched.epsilon); @@ -3209,6 +3192,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, if (IsConv2D(contraction)) { contraction_node.set_op(kFusedConv2D); + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &contraction_node); } else if (IsMatMul(contraction)) { contraction_node.set_op(kFusedMatMul); @@ -3309,6 +3293,7 @@ Status AddFusedContractionNode( if (IsConv2D(contraction)) { fused_conv.set_op(kFusedConv2D); + AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &fused_conv); } else if (IsConv3D(contraction)) { fused_conv.set_op(kFusedConv3D); @@ -4423,6 +4408,15 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, ContractionWithActivation contract_with_activation; ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation; + // Store dimensions so that they can be retrieved later in + // mkl_layout_rewrite_pass when deciding whether to rewrite node + if (IsConv2D(ctx.graph_view.graph()->node(i)) || + IsFusedBatchNorm(ctx.graph_view.graph()->node(i)) || + IsDepthwiseConv2dNative(ctx.graph_view.graph()->node(i)) || + IsSigmoid(ctx.graph_view.graph()->node(i))) { + AddInputShapesAttr(ctx, i); + } + if (IsMKLEnabled()) { // Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D. // or Remap Conv3D+BiasAdd+Add+relu into _FusedConv3D @@ -4515,17 +4509,6 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } - // Remap Mul(x, Sigmoid(x)) pattern, fuse them into the Swish(x). - std::map sigmoidmul_matched_nodes_map; - std::set sigmoidmul_remove_node_indices; - if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map, - &sigmoidmul_remove_node_indices)) { - TF_RETURN_IF_ERROR(ReplaceSigmoidMulWithSwish( - &ctx, sigmoidmul_matched_nodes_map, sigmoidmul_remove_node_indices, - &invalidated_nodes, &nodes_to_delete)); - continue; - } - // Remap smaller ops from layernorm python api into _MklLayerNorm matched_nodes_map.clear(); remove_node_indices.clear(); diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index fae0be99ac2903..075e60fbeb16ca 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -22,6 +22,11 @@ limitations under the License. #if defined(PLATFORM_IS_X86) #include // NOLINT #endif +#if defined(PLATFORM_IS_ARM64) +#include + +#include +#endif // SIMD extension querying is only available on x86. 
#ifdef PLATFORM_IS_X86 @@ -345,6 +350,86 @@ void InitCPUIDInfo() { #endif // PLATFORM_IS_X86 +#ifdef PLATFORM_IS_ARM64 + +class CPUIDInfo; +void InitCPUIDInfo(); + +CPUIDInfo* cpuid = nullptr; + +// Structure for basic CPUID info +class CPUIDInfo { + public: + CPUIDInfo() : implementer_(0), variant_(0), cpunum_(0) {} + + static void Initialize() { + // Initialize cpuid struct + CHECK(cpuid == nullptr) << __func__ << " ran more than once"; + cpuid = new CPUIDInfo; + + if (!(getauxval(AT_HWCAP) & HWCAP_CPUID)) { + return; + } + + std::ifstream CPUspresent; + CPUspresent.open("/sys/devices/system/cpu/present", std::ios::in); + int present_cpu = -1; + if (CPUspresent.is_open()) { + std::string line; + if (bool(getline(CPUspresent, line))) { + // We just need to find one CPU that is active + // from which we can read MIDR register to find + // implement, variant and revision information + auto ending = line.end(); + for (auto i = line.begin(); i < line.end(); ++i) { + if (*i == '-' || *i == ',') { + ending = i; + break; + } + } + line.erase(ending, line.end()); + // That should be the fist number + present_cpu = std::stoi(line); + } + } + + if (present_cpu == -1) { + return; + } + + std::stringstream str; + str << "/sys/devices/system/cpu/cpu" << present_cpu + << "/regs/identification/midr_el1"; + std::ifstream midr_el1_file(str.str(), std::ios::in); + if (midr_el1_file.is_open()) { + std::string line; + if (bool(getline(midr_el1_file, line))) { + uint32 midr_el1 = std::stoul(line, nullptr, 16); + + // Unpack variant and CPU ID + cpuid->implementer_ = (midr_el1 >> 24) & 0xFF; + cpuid->variant_ = (midr_el1 >> 20) & 0xF; + cpuid->cpunum_ = (midr_el1 >> 4) & 0xFFF; + } + } + } + + int implementer() const { return implementer_; } + int cpunum() const { return cpunum_; } + + private: + int implementer_; + int variant_; + int cpunum_; +}; + +absl::once_flag cpuid_once_flag; + +void InitCPUIDInfo() { + absl::call_once(cpuid_once_flag, CPUIDInfo::Initialize); +} + +#endif } // namespace bool TestCPUFeature(CPUFeature feature) { @@ -368,6 +453,9 @@ int CPUFamily() { #ifdef PLATFORM_IS_X86 InitCPUIDInfo(); return cpuid->family(); +#elif defined(PLATFORM_IS_ARM64) + InitCPUIDInfo(); + return cpuid->implementer(); #else return 0; #endif @@ -377,6 +465,9 @@ int CPUModelNum() { #ifdef PLATFORM_IS_X86 InitCPUIDInfo(); return cpuid->model_num(); +#elif defined(PLATFORM_IS_ARM64) + InitCPUIDInfo(); + return cpuid->cpunum(); #else return 0; #endif From 4c8e63e959cff57152c02aa20a880bf8ca7cb540 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Wed, 29 Mar 2023 09:17:22 +0100 Subject: [PATCH 032/410] Address comments from reviewers --- .../core/common_runtime/mkl_layout_pass.cc | 91 +++++++++++-------- .../core/common_runtime/mkl_layout_pass.h | 4 +- .../optimize_function_graph_utils.cc | 4 +- .../core/grappler/optimizers/remapper.cc | 4 +- tensorflow/tsl/platform/cpu_info.cc | 15 +-- 5 files changed, 70 insertions(+), 48 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 85099ac94eec96..0389dd6da4b1b8 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -51,8 +51,8 @@ limitations under the License. 
namespace tensorflow { /// This table contains for each node name descriptors on which -/// hardware to check whether we should rewrite the operations -/// to use MKL based on the parameters for heuristic +/// hardware to check whether we should rewrite the operations +/// to use oneDNN based on the parameters for heuristic. static const RewriteThreshold rewrite_thresholds[] = { #ifdef DNNL_AARCH64_USE_ACL {"Conv2D", 0x41, 0xd40, {0.9349, 22.603}}, @@ -252,7 +252,7 @@ static const RewriteThreshold rewrite_thresholds[] = { // class MklLayoutRewritePass : public GraphOptimizationPass { public: - MklLayoutRewritePass() : num_intra_threads_(0) { + MklLayoutRewritePass() : num_intra_threads_(port::MaxParallelism()) { // NOTE: names are alphabetically sorted. csinfo_.addn = "AddN"; csinfo_.avg_pool = "AvgPool"; @@ -765,7 +765,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { wsinfo_.push_back( {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3}); - // Rule for merging sigmoid and multiplication to swish + // Rule for merging sigmoid and multiplication to oneDNN swish. minfo_.push_back( {csinfo_.sigmoid, csinfo_.mul, csinfo_.swish, GetSigmoidAndMul}); // Add a rule for merging nodes @@ -836,7 +836,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Structure that carries the original rewrite info, but /// in this case it is using the function that can accept /// what is number of threads that will be used to run - /// the operation in parallel + /// the operation in parallel. typedef struct { RewriteInfo rinfo; std::function rewrite_rule; @@ -1027,10 +1027,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { private: /// Maintain info about nodes to rewrite std::vector rinfo_; - /// Mantain info about nodes to rewrite with additional + /// Maintain info about nodes to rewrite with additional /// information that holds number of threads that should - /// be used to run kernel on so that we can decide - /// whether it is worth rewriting op to run with MKL + /// be used to parallelise the kernel so that we can decide + /// whether it is worth rewriting op to run with oneDNN. std::vector rinfothr_; /// Maintain info about nodes to add workspace edge @@ -1045,7 +1045,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Maintain structure of constant strings static ConstStringsInfo csinfo_; - /// Number of threads used for intra-parallelism + /// Number of threads used for intra-parallelism. int num_intra_threads_; private: @@ -1140,10 +1140,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Node* m, Node* n); static Node* GetSigmoidAndMul(const Node* m) { - DCHECK(m); Node* n = nullptr; - if (m->type_string() == csinfo_.sigmoid) { + if (m && m->type_string() == csinfo_.sigmoid) { for (const Edge* e : m->out_edges()) { if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.mul && e->dst_input() == 0) { @@ -1868,11 +1867,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } static double CalculateNodeMFlops(const Node* n) { - // Check if we can obtained dimensions for this node + // Check if we can obtained dimensions for this node. std::vector shape_attrs; if (!TryGetNodeAttr(n->attrs(), "_input_shapes", &shape_attrs)) { // We can't obtain shape so we will revert to default behaviour - // to rewrite node + // to rewrite node. 
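+      // Callers treat the -1 returned here as "shape unknown": they skip the
+      // threshold comparison and keep the unconditional rewrite behaviour.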
return -1; } @@ -1909,11 +1908,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } static bool Conv2DRewrite(const Node* n, int threads) { - // Find out what are dimensions of the convolution - // If dimensions are small we will not rewrite node - // to use MKL operations as overhead to call into MKL - // data set up is higher then actual useful work we - // might end up doing + // Find out what are dimensions of the convolution, + // if dimensions are small we will not rewrite node + // to use oneDNN operations as overhead to call into oneDNN + // as data setup is higher then actual useful work we + // might end up doing. double total_mflops = CalculateNodeMFlops(n); double thr = FindRewriteThreshold(n, threads); @@ -1921,8 +1920,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } static bool FusedConv2DRewrite(const Node* n, int threads) { - // Decide whether it is worth rewriting it to MKL operation - // due to overheads as they will dominate for small shapes + // Decide whether it is worth rewriting it to oneDNN operation + // due to overheads as they will dominate for small shapes. if (!Conv2DRewrite(n, threads)) { return false; } @@ -3150,20 +3149,23 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, Node* m, Node* n) { - CHECK_EQ( - (m->type_string() == csinfo_.sigmoid && n->type_string() == csinfo_.mul), - true); + if (!(m->type_string() == csinfo_.sigmoid && + n->type_string() == csinfo_.mul)) { + return Status(absl::StatusCode::kCancelled, + "Mul doesn't follow Sigmoid. " + "Will skip node merge optimization"); + } // Decide whether it is worth optimizing SigMoid+Mul to - // call to _MklSiwsh + // call to _MklSwish. double total_mflops = CalculateNodeMFlops(m); double thr = FindRewriteThreshold(m, num_intra_threads_); if (total_mflops != -1 && total_mflops < thr) { - // Do not merge and execute them as they are - // because overhead of going to MKL will dominate - // any benefits from accelerating compute - return Status(error::Code::CANCELLED, + // Do not merge and execute them as they are, + // because overheads of going to oneDNN will dominate + // any benefits from accelerating compute. + return Status(absl::StatusCode::kCancelled, "Sigmoid and Mul operate on small shapes, " "so there is no benefit in optimizing to Swish. " "Will skip node merge optimization"); @@ -3186,14 +3188,25 @@ Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, gtl::InlinedVector, 4> sigmoid_in(sigmoid_num); FillInputs(sigmoid, &sigmoid_control_edges, &sigmoid_in); - CHECK_EQ(sigmoid->in_edges().size(), 1); // Sigmoid has 1 input - CHECK_EQ(mul->in_edges().size(), 2); // Mul has 2 inputs + // Sigmoid has 1 input. + if (sigmoid->in_edges().size() != 1) { + return Status(absl::StatusCode::kCancelled, + "Sigmoid must have only one input edge." + "Will skip node merge optimization"); + } + // Mul has 2 inputs. + if (mul->in_edges().size() != 2) { + return Status(absl::StatusCode::kCancelled, + "Mul must have only two input edges." + "Will skip node merge optimization"); + } + for (const Edge* e : mul->in_edges()) { - const int kFirstInputSlot = 0; // this should be sigmoid + const int kFirstInputSlot = 0; // This should be sigmoid. const int kSecondInputSlot = - 1; // this should be thae same input as in sigmoid + 1; // This should be the same input as in sigmoid. 
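+    // The two checks below enforce the Mul(x, Sigmoid(x)) pattern: fanin 0 of
+    // Mul must be produced by the Sigmoid node itself, and fanin 1 must come
+    // from the same producer that feeds the Sigmoid.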
if (e->dst_input() == kFirstInputSlot && e->src() != sigmoid) { - return Status(error::Code::INVALID_ARGUMENT, + return Status(absl::StatusCode::kInvalidArgument, "Sigmoid doesn't feed to Mul. " "Will skip node merge optimization"); } @@ -3201,7 +3214,7 @@ Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, TF_CHECK_OK(sigmoid->input_edge(kFirstInputSlot, &kSigmoidInputEdge)); if (e->dst_input() == kSecondInputSlot && e->src() != kSigmoidInputEdge->src()) { - return Status(error::Code::INVALID_ARGUMENT, + return Status(absl::StatusCode::kInvalidArgument, "Input to Sigmoid and Mul is not the same. " "Will skip node merge optimization"); } @@ -3209,10 +3222,10 @@ Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, NodeBuilder nb(mul->name(), csinfo_.swish); nb.Input(sigmoid_in[0].first, sigmoid_in[0].second); - nb.Attr("T", T_mul); // copy type attribute + nb.Attr("T", T_mul); // Copy type attribute. nb.Device(mul->def().device()); - // Create new node + // Create new node. Node* new_node; TF_CHECK_OK(nb.Finalize(&**g, &new_node)); @@ -3256,7 +3269,11 @@ Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, const int kMulOutputSlot = 0; auto new_edge = (*g)->AddEdge(new_node, kMulOutputSlot, e->dst(), e->dst_input()); - DCHECK(new_edge); + if (!new_edge) { + return Status(absl::StatusCode::kCancelled, + "Failed to create a new edge from new node to output." + "Will skip node merge optimization"); + } } } diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.h b/tensorflow/core/common_runtime/mkl_layout_pass.h index 8c14efb8e6dc68..9b5dbd6a0b0494 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.h +++ b/tensorflow/core/common_runtime/mkl_layout_pass.h @@ -32,8 +32,10 @@ struct RewriteThreshold { int cpu_model_num; // The model that is used to decide whether it is worth // accelerating operations using oneDNN is: + // // threshold = thread_synchronisation*thread_num + framework_tax - // which finds threshold when framework overhead and thread synchronisations + // + // This finds threshold when framework overhead and thread synchronisations // are amortized with amount of computation that has to be performed. // If we are below this threshold then we will not rewrite the operation to // to be run using oneDNN primitive. diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index d1394850c4a1f8..222dab9efbaabe 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -554,8 +554,8 @@ StatusOr OptimizeFunctionGraph( optimization_options.debug_filename_prefix = function_name; if (cpu_device->tensorflow_cpu_worker_threads() != nullptr) { - // Pass to the optimisation pass number of intra threads that are used to - // parallelise operations + // Forward to the optimisation pass number of intra threads that are used to + // parallelise operations. 
session_options.config.set_intra_op_parallelism_threads( cpu_device->tensorflow_cpu_worker_threads()->num_threads); } diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index f7807d273dab74..e3d539c97bb100 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -2736,7 +2736,7 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d, (*attr)["data_format"] = src_attr.at("data_format"); (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); // When copying attributes check whether this convolution has - // attribute that describes the shapes on which it is working + // attribute that describes the shapes on which it is working. if (IsMKLEnabled()) { (*attr)["_input_shapes"] = src_attr.at("_input_shapes"); } @@ -4409,7 +4409,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation; // Store dimensions so that they can be retrieved later in - // mkl_layout_rewrite_pass when deciding whether to rewrite node + // mkl_layout_rewrite_pass when deciding whether to rewrite node. if (IsConv2D(ctx.graph_view.graph()->node(i)) || IsFusedBatchNorm(ctx.graph_view.graph()->node(i)) || IsDepthwiseConv2dNative(ctx.graph_view.graph()->node(i)) || diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index 075e60fbeb16ca..081ff0a49d6639 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -357,14 +357,17 @@ void InitCPUIDInfo(); CPUIDInfo* cpuid = nullptr; -// Structure for basic CPUID info +// Structure for basic CPUID info. class CPUIDInfo { public: CPUIDInfo() : implementer_(0), variant_(0), cpunum_(0) {} static void Initialize() { - // Initialize cpuid struct - CHECK(cpuid == nullptr) << __func__ << " ran more than once"; + // Initialize cpuid struct. + if (cpuid != nullptr) { + return; + } + cpuid = new CPUIDInfo; if (!(getauxval(AT_HWCAP) & HWCAP_CPUID)) { @@ -379,7 +382,7 @@ class CPUIDInfo { if (bool(getline(CPUspresent, line))) { // We just need to find one CPU that is active // from which we can read MIDR register to find - // implement, variant and revision information + // implement, variant and revision information. auto ending = line.end(); for (auto i = line.begin(); i < line.end(); ++i) { if (*i == '-' || *i == ',') { @@ -388,7 +391,7 @@ class CPUIDInfo { } } line.erase(ending, line.end()); - // That should be the fist number + // That should be the fist number. present_cpu = std::stoi(line); } } @@ -406,7 +409,7 @@ class CPUIDInfo { if (bool(getline(midr_el1_file, line))) { uint32 midr_el1 = std::stoul(line, nullptr, 16); - // Unpack variant and CPU ID + // Unpack variant and CPU ID. 
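+        // MIDR_EL1 layout: implementer in bits [31:24], variant in bits
+        // [23:20], part number in bits [15:4]; the shifts below extract each
+        // field.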
cpuid->implementer_ = (midr_el1 >> 24) & 0xFF; cpuid->variant_ = (midr_el1 >> 20) & 0xF; cpuid->cpunum_ = (midr_el1 >> 4) & 0xFFF; From 9318fefbfb0c4b63ea82d976dbd2258ebdc3be0b Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 21 Apr 2023 11:21:02 +0100 Subject: [PATCH 033/410] Address comments from review --- tensorflow/core/common_runtime/mkl_layout_pass.cc | 14 +++++++------- tensorflow/core/common_runtime/mkl_layout_pass.h | 2 +- tensorflow/core/graph/mkl_graph_util.h | 1 + 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 0389dd6da4b1b8..90b8eeea2e25db 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -50,9 +50,9 @@ limitations under the License. namespace tensorflow { -/// This table contains for each node name descriptors on which -/// hardware to check whether we should rewrite the operations -/// to use oneDNN based on the parameters for heuristic. +// Table storing thread synchronization and framework overhead costs on each CPU +// architecture for each oneNN-eligible operation. Our heuristics use these +// costs to determine whether we should rewrite the operation to use oneDNN. static const RewriteThreshold rewrite_thresholds[] = { #ifdef DNNL_AARCH64_USE_ACL {"Conv2D", 0x41, 0xd40, {0.9349, 22.603}}, @@ -330,6 +330,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D"; csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D"; csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D"; + csinfo_.mkl_swish = "_MklSwish"; csinfo_.pad = "Pad"; csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D"; csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D"; @@ -394,7 +395,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.squared_difference = "SquaredDifference"; csinfo_.sub = "Sub"; csinfo_.sigmoid = "Sigmoid"; - csinfo_.swish = "_MklSwish"; // End - element-wise ops. See note above. const bool native_fmt = NativeFormatEnabled(); @@ -767,7 +767,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Rule for merging sigmoid and multiplication to oneDNN swish. minfo_.push_back( - {csinfo_.sigmoid, csinfo_.mul, csinfo_.swish, GetSigmoidAndMul}); + {csinfo_.sigmoid, csinfo_.mul, csinfo_.mkl_swish, GetSigmoidAndMul}); // Add a rule for merging nodes minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); @@ -977,6 +977,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string mkl_native_pad_with_fused_conv2d; string mkl_pad_with_conv2d; string mkl_pad_with_fused_conv2d; + string mkl_swish; string mul; string pad; string pad_with_conv2d; @@ -1012,7 +1013,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string relu6_grad; string requantize; string sigmoid; - string swish; string tanh; string tanh_grad; string transpose; @@ -3220,7 +3220,7 @@ Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, } } - NodeBuilder nb(mul->name(), csinfo_.swish); + NodeBuilder nb(mul->name(), csinfo_.mkl_swish); nb.Input(sigmoid_in[0].first, sigmoid_in[0].second); nb.Attr("T", T_mul); // Copy type attribute. 
nb.Device(mul->def().device()); diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.h b/tensorflow/core/common_runtime/mkl_layout_pass.h index 9b5dbd6a0b0494..d4303f1f1e8651 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.h +++ b/tensorflow/core/common_runtime/mkl_layout_pass.h @@ -33,7 +33,7 @@ struct RewriteThreshold { // The model that is used to decide whether it is worth // accelerating operations using oneDNN is: // - // threshold = thread_synchronisation*thread_num + framework_tax + // threshold = thread_synchronisation * thread_num + framework_tax // // This finds threshold when framework overhead and thread synchronisations // are amortized with amount of computation that has to be performed. diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 886c4051c8ca13..8cf7788327a814 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -288,6 +288,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) { 0 == op_name.compare(GetMklOpName("Sub")) || 0 == op_name.compare(GetMklOpName("Mul")) || 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("Sigmoid")) || 0 == op_name.compare(GetMklOpName("SquaredDifference"))); return result; From 8eca57d7a7504e21eca966df94809d78175995bb Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 4 Jul 2023 09:48:10 +0100 Subject: [PATCH 034/410] Address comments from reviewers --- .../core/common_runtime/mkl_layout_pass.cc | 245 +----------------- tensorflow/core/grappler/grappler_item.h | 4 + .../grappler/optimizers/meta_optimizer.cc | 6 + .../core/grappler/optimizers/remapper.cc | 89 ++++++- tensorflow/core/util/BUILD | 18 ++ tensorflow/core/util/mkl_heuristics.h | 123 +++++++++ tensorflow/core/util/mkl_heuristics_test.cc | 120 +++++++++ 7 files changed, 363 insertions(+), 242 deletions(-) create mode 100644 tensorflow/core/util/mkl_heuristics.h create mode 100644 tensorflow/core/util/mkl_heuristics_test.cc diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 90b8eeea2e25db..bc21b8b019352b 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -45,23 +45,12 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/mkl_heuristics.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/util.h" namespace tensorflow { -// Table storing thread synchronization and framework overhead costs on each CPU -// architecture for each oneNN-eligible operation. Our heuristics use these -// costs to determine whether we should rewrite the operation to use oneDNN. 
-static const RewriteThreshold rewrite_thresholds[] = { -#ifdef DNNL_AARCH64_USE_ACL - {"Conv2D", 0x41, 0xd40, {0.9349, 22.603}}, - {"_FusedConv2D", 0x41, 0xd40, {0.9349, 22.603}}, - {"FusedBatchNormV3", 0x41, 0xd40, {0.3223, -0.8822}}, - {"Sigmoid", 0x41, 0xd40, {0.0, 0.064736}}, -#endif // DNNL_AARCH64_USE_ACL - {"", 0x0, 0x0, {0, 0}}}; - // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph // (B) Rewriting a node in the graph to a new node @@ -765,9 +754,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { wsinfo_.push_back( {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3}); - // Rule for merging sigmoid and multiplication to oneDNN swish. - minfo_.push_back( - {csinfo_.sigmoid, csinfo_.mul, csinfo_.mkl_swish, GetSigmoidAndMul}); // Add a rule for merging nodes minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); @@ -1133,26 +1119,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Status MergeNode(std::unique_ptr* g, Node* m, Node* n); // Helper function to merge different nodes - Status MergeSigmoidWithMul(std::unique_ptr* g, Node* m, Node* n); Status MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n); Status MergePadWithConv2D(std::unique_ptr* g, Node* m, Node* n); Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr* g, Node* m, Node* n); - static Node* GetSigmoidAndMul(const Node* m) { - Node* n = nullptr; - - if (m && m->type_string() == csinfo_.sigmoid) { - for (const Edge* e : m->out_edges()) { - if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.mul && - e->dst_input() == 0) { - n = e->dst(); - break; - } - } - } - return n; - } // Find BiasAdd or Conv2D node that can be merged with input node 'm'. // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd @@ -1193,27 +1164,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return n; } - static double FindRewriteThreshold(const Node* n, int threads) { - int cpu_family_ = port::CPUFamily(); - int cpu_model_num_ = port::CPUModelNum(); - - if (threads == 0) { - // if we do not have information how many threads are used - // to parallelise operation we revert to the old behaviour - return 0; - } - - for (const RewriteThreshold* i = rewrite_thresholds; - i->op != "" && threads > 0; i++) { - if (n->type_string() == i->op && cpu_family_ == i->cpu_family && - cpu_model_num_ == i->cpu_model_num) { - return i->params.thread_sync_cost * threads + i->params.framework_cost; - } - } - - return 0; - } - // Find Pad or Conv2D node that can be merged with input node 'm'. // If input 'm' is Pad, then check if there exists Conv2D node that can be // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad @@ -1832,8 +1782,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } static bool FusedBatchNormV3RewriteWithThreads(const Node* n, int threads) { - double mflops = CalculateNodeMFlops(n); - double thr = FindRewriteThreshold(n, threads); + double mflops = CalculateNodeMFlops(n->attrs(), n->type_string()); + double thr = FindRewriteThreshold(n->type_string(), threads); if (mflops > 0 && mflops < thr) { return false; } @@ -1866,55 +1816,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return true; } - static double CalculateNodeMFlops(const Node* n) { - // Check if we can obtained dimensions for this node. 
- std::vector shape_attrs; - if (!TryGetNodeAttr(n->attrs(), "_input_shapes", &shape_attrs)) { - // We can't obtain shape so we will revert to default behaviour - // to rewrite node. - return -1; - } - - if ((n->type_string() == "Conv2D" || n->type_string() == "_FusedConv2D") && - shape_attrs.size() == 2) { - TensorShape input_shape, filter_shape; - if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != - tsl::OkStatus()) { - return -1; - } - if (TensorShape::BuildTensorShape(*shape_attrs[1], &filter_shape) != - tsl::OkStatus()) { - return -1; - } - - // MFLOPS = N * H * W * C * FH * FW * FC / 1e6. - return input_shape.dim_size(0) * input_shape.dim_size(1) * - input_shape.dim_size(2) * input_shape.dim_size(3) * - filter_shape.dim_size(0) * filter_shape.dim_size(1) * - filter_shape.dim_size(3) / (double)1e6; - } else if ((n->type_string() == "FusedBatchNormV3" || - n->type_string() == "Sigmoid") && - shape_attrs.size() >= 1) { - TensorShape input_shape; - if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != - tsl::OkStatus()) { - return -1; - } - return input_shape.dim_size(0) * input_shape.dim_size(1) * - input_shape.dim_size(2) * input_shape.dim_size(3) / (double)1e6; - } - - return -1; - } - static bool Conv2DRewrite(const Node* n, int threads) { // Find out what are dimensions of the convolution, // if dimensions are small we will not rewrite node // to use oneDNN operations as overhead to call into oneDNN // as data setup is higher then actual useful work we // might end up doing. - double total_mflops = CalculateNodeMFlops(n); - double thr = FindRewriteThreshold(n, threads); + double total_mflops = CalculateNodeMFlops(n->attrs(), n->type_string()); + double thr = FindRewriteThreshold(n->type_string(), threads); return true ? (total_mflops < 0 || total_mflops >= thr) : false; } @@ -3147,147 +3056,6 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { return nullptr; } -Status MklLayoutRewritePass::MergeSigmoidWithMul(std::unique_ptr* g, - Node* m, Node* n) { - if (!(m->type_string() == csinfo_.sigmoid && - n->type_string() == csinfo_.mul)) { - return Status(absl::StatusCode::kCancelled, - "Mul doesn't follow Sigmoid. " - "Will skip node merge optimization"); - } - - // Decide whether it is worth optimizing SigMoid+Mul to - // call to _MklSwish. - double total_mflops = CalculateNodeMFlops(m); - double thr = FindRewriteThreshold(m, num_intra_threads_); - - if (total_mflops != -1 && total_mflops < thr) { - // Do not merge and execute them as they are, - // because overheads of going to oneDNN will dominate - // any benefits from accelerating compute. - return Status(absl::StatusCode::kCancelled, - "Sigmoid and Mul operate on small shapes, " - "so there is no benefit in optimizing to Swish. " - "Will skip node merge optimization"); - } - - Node* sigmoid = m; - Node* mul = n; - - DataType T_sigmoid, T_mul; - TF_CHECK_OK(GetNodeAttr(sigmoid->def(), "T", &T_sigmoid)); - TF_CHECK_OK(GetNodeAttr(mul->def(), "T", &T_mul)); - - const int mul_num = mul->num_inputs(); - gtl::InlinedVector mul_control_edges; - gtl::InlinedVector, 4> mul_in(mul_num); - FillInputs(mul, &mul_control_edges, &mul_in); - - const int sigmoid_num = sigmoid->num_inputs(); - gtl::InlinedVector sigmoid_control_edges; - gtl::InlinedVector, 4> sigmoid_in(sigmoid_num); - FillInputs(sigmoid, &sigmoid_control_edges, &sigmoid_in); - - // Sigmoid has 1 input. - if (sigmoid->in_edges().size() != 1) { - return Status(absl::StatusCode::kCancelled, - "Sigmoid must have only one input edge." 
- "Will skip node merge optimization"); - } - // Mul has 2 inputs. - if (mul->in_edges().size() != 2) { - return Status(absl::StatusCode::kCancelled, - "Mul must have only two input edges." - "Will skip node merge optimization"); - } - - for (const Edge* e : mul->in_edges()) { - const int kFirstInputSlot = 0; // This should be sigmoid. - const int kSecondInputSlot = - 1; // This should be the same input as in sigmoid. - if (e->dst_input() == kFirstInputSlot && e->src() != sigmoid) { - return Status(absl::StatusCode::kInvalidArgument, - "Sigmoid doesn't feed to Mul. " - "Will skip node merge optimization"); - } - const Edge* kSigmoidInputEdge = nullptr; - TF_CHECK_OK(sigmoid->input_edge(kFirstInputSlot, &kSigmoidInputEdge)); - if (e->dst_input() == kSecondInputSlot && - e->src() != kSigmoidInputEdge->src()) { - return Status(absl::StatusCode::kInvalidArgument, - "Input to Sigmoid and Mul is not the same. " - "Will skip node merge optimization"); - } - } - - NodeBuilder nb(mul->name(), csinfo_.mkl_swish); - nb.Input(sigmoid_in[0].first, sigmoid_in[0].second); - nb.Attr("T", T_mul); // Copy type attribute. - nb.Device(mul->def().device()); - - // Create new node. - Node* new_node; - TF_CHECK_OK(nb.Finalize(&**g, &new_node)); - - std::unordered_set unique_node; - for (const Edge* e : sigmoid->in_edges()) { - if (e->IsControlEdge()) { - auto result = unique_node.insert(e->src()); - if (result.second) { - (*g)->AddControlEdge(e->src(), new_node, true); - } - } - } - unique_node.clear(); - - for (const Edge* e : mul->in_edges()) { - if (e->IsControlEdge()) { - auto result = unique_node.insert(e->src()); - if (result.second) { - (*g)->AddControlEdge(e->src(), new_node, true); - } - } - } - unique_node.clear(); - - for (const Edge* e : sigmoid->out_edges()) { - if (e->IsControlEdge()) { - auto result = unique_node.insert(e->dst()); - if (result.second) { - (*g)->AddControlEdge(new_node, e->dst(), true); - } - } - } - unique_node.clear(); - for (const Edge* e : mul->out_edges()) { - if (e->IsControlEdge()) { - auto results = unique_node.insert(e->dst()); - if (results.second) { - (*g)->AddControlEdge(new_node, e->dst(), true); - } - } else { - const int kMulOutputSlot = 0; - auto new_edge = - (*g)->AddEdge(new_node, kMulOutputSlot, e->dst(), e->dst_input()); - if (!new_edge) { - return Status(absl::StatusCode::kCancelled, - "Failed to create a new edge from new node to output." 
- "Will skip node merge optimization"); - } - } - } - - new_node->set_assigned_device_name(mul->assigned_device_name()); - VLOG(1) << "MklLayoutRewritePass: Merged old node: " << sigmoid->DebugString() - << ", and node: " << mul->DebugString() - << ", into node: " << new_node->DebugString(); - - (*g)->RemoveNode(sigmoid); - (*g)->RemoveNode(mul); - - return OkStatus(); -} - Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr* g, Node* m, Node* n) { CHECK_EQ(((m->type_string() == csinfo_.bias_add && @@ -3771,9 +3539,6 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr* g, Node* m, DCHECK(m); DCHECK(n); - if (m->type_string() == csinfo_.sigmoid && n->type_string() == csinfo_.mul) { - return this->MergeSigmoidWithMul(g, m, n); - } if (((m->type_string() == csinfo_.bias_add && n->type_string() == csinfo_.conv2d)) || ((n->type_string() == csinfo_.bias_add && diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index c7faf23566adf9..8fcfba288f2f38 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variable.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tensorflow/tsl/platform/cpu_info.h" namespace tensorflow { namespace grappler { @@ -102,6 +103,9 @@ struct GrapplerItem { // Mark the grapper optimization run in eager mode or not. bool is_eager_mode = false; + + // Number of intra threads used to run operation. + int intra_op_parallelism_threads = tsl::port::MaxParallelism(); }; const std::unordered_set& devices() const; diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 605b18bf102016..0680d25f5f8262 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -1383,6 +1383,12 @@ Status OptimizeGraph( tensorflow::grappler::GrapplerItem item; item.id = grappler_item_id; item.optimization_options() = optimization_options; + if (cpu_device->tensorflow_cpu_worker_threads() != nullptr) { + // Forward to the optimisation pass number of intra threads that are used to + // parallelise operations. + item.optimization_options().intra_op_parallelism_threads = + cpu_device->tensorflow_cpu_worker_threads()->num_threads; + } // Add all available devices so that inlined function can be placed. for (const Device* d : device_set.devices()) { diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index e3d539c97bb100..70c3a86e6f236c 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -40,6 +40,9 @@ limitations under the License. 
#include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/use_cudnn.h" +#ifdef INTEL_MKL +#include "tensorflow/core/util/mkl_heuristics.h" +#endif // INTEL_MKL #include "tensorflow/core/util/util.h" #if GOOGLE_CUDA @@ -744,7 +747,7 @@ void AddInputShapesAttr(const RemapperContext& ctx, int node_index) { *proto = tensor_property.shape(); } - if (IsMKLEnabled()) { + if (IsMKLEnabled() && tensor_properties.size() > 0) { (*mutable_node->mutable_attr())["_input_shapes"] = std::move(attr_input_shape); } @@ -1836,6 +1839,50 @@ bool FindMulAndMaximum(RemapperContext* ctx, int node_index, return found_op_type_match; } +bool FindSigmoidAndMul(RemapperContext* ctx, int node_index, + std::map* matched_nodes_map, + std::set* remove_node_indices) { + // Gelu fusion is enabled only with oneDNN library. + if (!IsMKLEnabled()) return false; + + using utils::MatchingDirection; + using utils::NodeStatus; + // clang-format off + // Convert Sigmoid+Mul to Swish + // Mul(x, Sigmoid(x)) --> _MklSwish(x) + + utils::OpTypePattern sigmoidmul_pattern{ + "Mul", "mul_to_swish", NodeStatus::kReplace, + { + { "Sigmoid", "sigmoid", NodeStatus::kRemove, + { + { "*", "input", NodeStatus::kRemain} + } + }, + { "*", "input", NodeStatus::kRemain} + } + }; + // clang-format on + // check for data types + auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node(); + if (!HasDataType(mul_node_def, DT_FLOAT) && + !HasDataType(mul_node_def, DT_BFLOAT16)) + return false; + + if (!NodeIsOnCpu(mul_node_def)) return false; + + bool found_op_type_match = false; + utils::SubGraphMatcher graph_matcher( + &(ctx->graph_view)); + matched_nodes_map->clear(); + remove_node_indices->clear(); + found_op_type_match = graph_matcher.GetMatchedNodes( + sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index), + matched_nodes_map, remove_node_indices); + + return found_op_type_match; +} + // Keras LayerNormalization api uses multiple TensorFlow ops. Current fusion // pattern is only for the case, when LayerNormalization uses FusedBatcNormV3. // We further restrict it to only 2D or 3D tensor inputs to keras @@ -2737,7 +2784,7 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d, (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); // When copying attributes check whether this convolution has // attribute that describes the shapes on which it is working. - if (IsMKLEnabled()) { + if (IsMKLEnabled() && src_attr.find("_input_shapes") != src_attr.end()) { (*attr)["_input_shapes"] = src_attr.at("_input_shapes"); } // Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha @@ -3351,6 +3398,7 @@ Status FuseConv2DSwish(RemapperContext* ctx, SetFusedOpAttributes(&fused_op, {"FusedBatchNorm", "_MklSwish"}, /*num_args=*/4, /*epsilon=*/epsilon); } + AddInputShapesAttr(*ctx, matched_nodes_map.at("conv")); CopyConv2DAttributes(*conv2d, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); @@ -4376,6 +4424,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {})); const int num_nodes = item.graph.node_size(); + const int intra_op_parallelism_threads = + item.optimization_options().intra_op_parallelism_threads; // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd // and Activation nodes that were fused into a Conv2D node. 
std::vector invalidated_nodes(num_nodes); @@ -4509,6 +4559,41 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } + // Remap Mul(x, Sigmoid(x)) pattern, fuse them into the Swish(x). + std::map sigmoidmul_matched_nodes_map; + std::set sigmoidmul_remove_node_indices; + if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map, + &sigmoidmul_remove_node_indices)) { + bool replace = true; +#ifdef DNNL_AARCH64_USE_ACL + // Need to check whether the cost of rewriting node + // to execute using oneDNN kernel will be amortised + // based on the size of the input + const int sigmoid_idx = sigmoidmul_matched_nodes_map.at("sigmoid"); + // We need to infer what is the shape of sigmoid + AddInputShapesAttr(ctx, sigmoid_idx); + const NodeDef* sigmoid = ctx.graph_view.GetNode(sigmoid_idx)->node(); + + double total_mflops = + CalculateNodeMFlops(AttrSlice(*sigmoid), "Sigmoid"); + double thr = + FindRewriteThreshold("Sigmoid", intra_op_parallelism_threads); + + if (total_mflops != -1 && total_mflops < thr) { + // The overhead of using oneDNN kernel is not amortized + // so we are not going to rewrite node + replace = false; + } +#endif + if (replace) { + TF_RETURN_IF_ERROR( + ReplaceSigmoidMulWithSwish(&ctx, sigmoidmul_matched_nodes_map, + sigmoidmul_remove_node_indices, + &invalidated_nodes, &nodes_to_delete)); + continue; + } + } + // Remap smaller ops from layernorm python api into _MklLayerNorm matched_nodes_map.clear(); remove_node_indices.clear(); diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index f7436f0a35c7b1..6c122b264aeb0b 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -11,6 +11,7 @@ load( "check_deps", "tf_cc_test", "tf_cc_tests", + "tf_cc_test_mkl", "tf_copts", "tf_cuda_library", "tf_cuda_only_cc_test", @@ -163,6 +164,7 @@ filegroup( "matmul_autotune.h", "matmul_bcast.h", "mirror_pad_mode.h", + "mkl_heuristics.h", "mkl_util.h", "onednn_env_vars.h", "overflow.h", @@ -295,6 +297,7 @@ filegroup( filegroup( name = "mkl_util_hdrs", srcs = [ + "mkl_heuristics.h", "mkl_util.h", "onednn_env_vars.h", "//tensorflow/tsl/util:onednn_util_hdrs", @@ -957,6 +960,21 @@ tf_cc_test( ], ) +tf_cc_test_mkl( + name = "mkl_heuristics_test", + size = "small", + srcs = ["mkl_heuristics_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:graph", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:ops_testutil", + ], +) + # Proto libraries. tf_proto_library( name = "test_log_proto", diff --git a/tensorflow/core/util/mkl_heuristics.h b/tensorflow/core/util/mkl_heuristics.h new file mode 100644 index 00000000000000..162dbe81331c90 --- /dev/null +++ b/tensorflow/core/util/mkl_heuristics.h @@ -0,0 +1,123 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file contains heuristics data and methods that are used to +// decide whether to rewrite node to use oneDNN kernels + +#ifndef TENSORFLOW_CORE_UTIL_MKL_HEURISTICS_H +#define TENSORFLOW_CORE_UTIL_MKL_HEURISTICS_H_ +#ifdef INTEL_MKL + +#include "tensorflow/tsl/platform/cpu_info.h" + +namespace tensorflow { + +struct RewriteThreshold { + string op; + int cpu_family; + int cpu_model_num; + // The model that is used to decide whether it is worth + // accelerating operations using oneDNN is: + // + // threshold = thread_synchronisation * thread_num + framework_tax + // + // This finds threshold when framework overhead and thread synchronisations + // are amortized with amount of computation that has to be performed. + // If we are below this threshold then we will not rewrite the operation to + // to be run using oneDNN primitive. + struct PerformanceParameters { + double thread_sync_cost; + double framework_cost; + } params; +}; + +// Table storing thread synchronization and framework overhead costs on each CPU +// architecture for each oneNN-eligible operation. Our heuristics use these +// costs to determine whether we should rewrite the operation to use oneDNN. +static const RewriteThreshold rewrite_thresholds[] = { +#ifdef DNNL_AARCH64_USE_ACL + {"Conv2D", 0x41, 0xd40, {0.9349, 22.603}}, + {"_FusedConv2D", 0x41, 0xd40, {0.9349, 22.603}}, + {"FusedBatchNormV3", 0x41, 0xd40, {0.3223, -0.8822}}, + {"Sigmoid", 0x41, 0xd40, {0.0, 0.064736}}, +#endif // DNNL_AARCH64_USE_ACL + {"", 0x0, 0x0, {0, 0}}}; + +static double FindRewriteThreshold(const string node_name, int threads) { + int cpu_family_ = tsl::port::CPUFamily(); + int cpu_model_num_ = tsl::port::CPUModelNum(); + + if (threads == 0) { + // if we do not have information how many threads are used + // to parallelise operation we revert to the old behaviour + return 0; + } + + for (const RewriteThreshold* i = rewrite_thresholds; + i->op != "" && threads > 0; i++) { + if (node_name == i->op && cpu_family_ == i->cpu_family && + cpu_model_num_ == i->cpu_model_num) { + return i->params.thread_sync_cost * threads + i->params.framework_cost; + } + } + + return 0; +} + +static double CalculateNodeMFlops(const AttrSlice& attrs, + const string node_name) { + // Check if we can obtained dimensions for this node. + std::vector shape_attrs; + if (!TryGetNodeAttr(attrs, "_input_shapes", &shape_attrs)) { + // We can't obtain shape so we will revert to default behaviour + // to rewrite node. + return -1; + } + + if ((node_name == "Conv2D" || node_name == "_FusedConv2D") && + shape_attrs.size() == 2) { + TensorShape input_shape, filter_shape; + if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != + tsl::OkStatus()) { + return -1; + } + if (TensorShape::BuildTensorShape(*shape_attrs[1], &filter_shape) != + tsl::OkStatus()) { + return -1; + } + + // MFLOPS = N * H * W * C * FH * FW * FC / 1e6. 
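+    // For example, the shapes used in mkl_heuristics_test.cc (an 8x32x32x3
+    // input with a 3x3x3x64 filter) give 8*32*32*3 * 3*3*64 / 1e6 ~ 14.16
+    // MFLOPS.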
+ return input_shape.dim_size(0) * input_shape.dim_size(1) * + input_shape.dim_size(2) * input_shape.dim_size(3) * + filter_shape.dim_size(0) * filter_shape.dim_size(1) * + filter_shape.dim_size(3) / (double)1e6; + } else if ((node_name == "FusedBatchNormV3" || node_name == "Sigmoid") && + shape_attrs.size() >= 1) { + TensorShape input_shape; + if (TensorShape::BuildTensorShape(*shape_attrs[0], &input_shape) != + tsl::OkStatus()) { + return -1; + } + return input_shape.dim_size(0) * input_shape.dim_size(1) * + input_shape.dim_size(2) * input_shape.dim_size(3) / (double)1e6; + } + + return -1; +} + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_UTIL_MKL_HEURISTICS_H_ diff --git a/tensorflow/core/util/mkl_heuristics_test.cc b/tensorflow/core/util/mkl_heuristics_test.cc new file mode 100644 index 00000000000000..5dfdd1fa24a051 --- /dev/null +++ b/tensorflow/core/util/mkl_heuristics_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL +#define EIGEN_USE_THREADS + +#include "tensorflow/core/util/mkl_heuristics.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +TEST(MklHeuristicsTest, MklCalculateMFlops) { + int batch = 8; + int width = 32; + int height = 32; + int in_depth = 3; + + int filter_h = 3; + int filter_w = 3; + int out_depth = 64; + + // Test calculation for number of MFLOPs for convolution + AttrValue attr_input_shape; + TensorShapeProto* proto = attr_input_shape.mutable_list()->add_shape(); + proto->add_dim()->set_size(batch); + proto->add_dim()->set_size(width); + proto->add_dim()->set_size(height); + proto->add_dim()->set_size(in_depth); + proto = attr_input_shape.mutable_list()->add_shape(); + proto->add_dim()->set_size(filter_h); + proto->add_dim()->set_size(filter_w); + proto->add_dim()->set_size(in_depth); + proto->add_dim()->set_size(out_depth); + + NodeDef ndef; + + // If node doesn't have any _input_shapes it should return -1 + double calculated_empty_mflops = + CalculateNodeMFlops(AttrSlice(ndef), "Conv2D"); + EXPECT_EQ(calculated_empty_mflops, -1); + + (*ndef.mutable_attr())["_input_shapes"] = attr_input_shape; + + double conv_calculated_mflops = + CalculateNodeMFlops(AttrSlice(ndef), "Conv2D"); + double expected_conv_mflops = batch * width * height * in_depth * filter_h * + filter_w * out_depth / double(1e6); + EXPECT_EQ(conv_calculated_mflops, expected_conv_mflops); + + // We should get the same calculation for fused convolution too + double fused_calculated_mflops = + CalculateNodeMFlops(AttrSlice(ndef), "_FusedConv2D"); + EXPECT_EQ(conv_calculated_mflops, expected_conv_mflops); + + // Finally calculate for sigmoid number of MFLOPS + double sigmoid_calculated_mflops = + CalculateNodeMFlops(AttrSlice(ndef), "Sigmoid"); + double 
expected_sigmoid_mflops = + batch * width * height * in_depth / double(1e6); + EXPECT_EQ(sigmoid_calculated_mflops, expected_sigmoid_mflops); +} + +#ifdef DNNL_AARCH64_USE_ACL +TEST(MklHeuristicsTest, MklThresholds) { + int cpu_family = tsl::port::CPUFamily(); + int cpu_model_num = tsl::port::CPUModelNum(); + + int neoverse_v1_family = 0x41; + int neoverse_v1_model = 0xd40; + + string op_type = "Conv2D"; + + if (neoverse_v1_family == cpu_family && neoverse_v1_model == cpu_model_num) { + double thread_sync_cost = -1; + double framework_cost = -1; + for (const RewriteThreshold* i = rewrite_thresholds; i->op != ""; i++) { + if (i->op == op_type) { + thread_sync_cost = i->params.thread_sync_cost; + framework_cost = i->params.framework_cost; + break; + } + } + + EXPECT_NE(thread_sync_cost, -1); + EXPECT_NE(thread_sync_cost, -1); + + int no_threads = 0; + double calculated_threshold_zero_threads = + FindRewriteThreshold(op_type, no_threads); + EXPECT_EQ(calculated_threshold_zero_threads, 0); + + int threads = 8; + double calculated_threshold = FindRewriteThreshold(op_type, threads); + double expected_threshold = threads * thread_sync_cost + framework_cost; + EXPECT_EQ(expected_threshold, calculated_threshold); + } +} +#endif // DNNL_AARCG64_USE_ACL + +} // namespace +} // namespace tensorflow + +#endif // INTEL_MKL \ No newline at end of file From 2bb6bb743d595ec6eef4fe0344167ddfa4117aba Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 4 Jul 2023 11:44:47 +0100 Subject: [PATCH 035/410] Remove definition of RewriteThreshold from mkl_layout_pass.h --- .../core/common_runtime/mkl_layout_pass.h | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.h b/tensorflow/core/common_runtime/mkl_layout_pass.h index d4303f1f1e8651..6b5c586ceabb3e 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.h +++ b/tensorflow/core/common_runtime/mkl_layout_pass.h @@ -25,26 +25,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" namespace tensorflow { - -struct RewriteThreshold { - string op; - int cpu_family; - int cpu_model_num; - // The model that is used to decide whether it is worth - // accelerating operations using oneDNN is: - // - // threshold = thread_synchronisation * thread_num + framework_tax - // - // This finds threshold when framework overhead and thread synchronisations - // are amortized with amount of computation that has to be performed. - // If we are below this threshold then we will not rewrite the operation to - // to be run using oneDNN primitive. - struct PerformanceParameters { - double thread_sync_cost; - double framework_cost; - } params; -}; - // Interface to invoke the pass for unit test // // Returns true if and only if 'g' is mutated. 
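
The heuristic introduced above reduces to a single decision per candidate node: estimate the work in MFLOPs from the node's "_input_shapes" attribute and compare it with threshold = thread_sync_cost * threads + framework_cost for the current CPU. The sketch below shows that decision in isolation. It is illustrative only and not taken from any patch in this series: the wrapper name ShouldRewriteToOneDnn is invented, an INTEL_MKL build is assumed, and only FindRewriteThreshold and CalculateNodeMFlops come from mkl_heuristics.h above.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/util/mkl_heuristics.h"

namespace tensorflow {

// Decide whether rewriting `node` to a oneDNN kernel is expected to pay off.
bool ShouldRewriteToOneDnn(const NodeDef& node, int intra_op_threads) {
  // Per-op, per-CPU threshold; returns 0 when there is no calibration data
  // or the thread count is unknown, which keeps the old always-rewrite path.
  const double threshold = FindRewriteThreshold(node.op(), intra_op_threads);
  // Work estimate in MFLOPs from the "_input_shapes" attribute; -1 = unknown.
  const double mflops = CalculateNodeMFlops(AttrSlice(node), node.op());
  // Unknown work: rewrite, matching the remapper fallback above.
  if (mflops == -1) return true;
  // Rewrite only when the computation amortizes the oneDNN overhead.
  return mflops >= threshold;
}

}  // namespace tensorflow
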
From 7123130f71d0d93c5072613e52095dd5fb9f70cc Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 4 Jul 2023 12:20:59 +0100 Subject: [PATCH 036/410] Additional shapes for which to infer dimensions --- tensorflow/core/grappler/optimizers/remapper.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 70c3a86e6f236c..e9488d3358ff3b 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -4401,7 +4401,9 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index, if (IsMKLEnabled()) return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() || IsContractionWithAdd(ctx, node_index) || - is_act_biasadd_conv_candidate(); + is_act_biasadd_conv_candidate() || + IsBiasAdd(*node_def) || + IsTranspose(*node_def); return is_act_biasadd_conv_candidate() || is_batch_norm_candidate() || is_batch_norm_fusion_candidate() || @@ -4463,6 +4465,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, if (IsConv2D(ctx.graph_view.graph()->node(i)) || IsFusedBatchNorm(ctx.graph_view.graph()->node(i)) || IsDepthwiseConv2dNative(ctx.graph_view.graph()->node(i)) || + IsBiasAdd(ctx.graph_view.graph()->node(i)) || + IsTranspose(ctx.graph_view.graph()->node(i)) || IsSigmoid(ctx.graph_view.graph()->node(i))) { AddInputShapesAttr(ctx, i); } From e30650cbfa6bff4192c558f17a366dcf48639673 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Tue, 4 Jul 2023 14:01:21 +0100 Subject: [PATCH 037/410] Fix HWCAP_CPUID when not defined in some cases --- tensorflow/tsl/platform/cpu_info.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index 081ff0a49d6639..1f9d6905fe4e20 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -24,7 +24,9 @@ limitations under the License. #endif #if defined(PLATFORM_IS_ARM64) #include - +#ifndef HWCAP_CPUID +#define HWCAP_CPUID (1 << 11) +#endif #include #endif From 2ef4747ec75ef5057698acce25715e33c121a388 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 20 Jul 2023 11:40:18 -0700 Subject: [PATCH 038/410] Improve based on review 1. --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 103 ++++++++++-------- .../service/gpu/tests/gemm_rewrite_test.cc | 68 +++++++++++- 2 files changed, 125 insertions(+), 46 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 71c79825bec958..9de8bdf2ec0e79 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -112,27 +112,36 @@ Shape PadShapeToMultipleOf16(const Shape old_shape, return padded_shape; } -// Pad the non-batch dimensions of the operands to multiples of 16 as required -// by cuBLASLt. -HloInstruction *PadOperandToMultipleOf16(absl::Span batch_dims, - HloInstruction *instr, - HloInstruction *x) { +// Pad the non-batch dimensions of the operands to the target shape. 
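// Illustrative example, not part of this patch: padding an f8e4m3fn[14,31]
// operand to a [16,32] target adds zeros only at the high end of each
// dimension (padding=0_2x0_1); edge_padding_low and interior padding stay 0.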
+HloInstruction *PadOperandToTargetShape(const Shape &target, + HloInstruction *instr, + HloInstruction *x) { + if (ShapeUtil::Equal(target, x->shape())) { + return x; + } + PaddingConfig padding_config; - Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims); for (int i = 0; i < x->shape().rank(); ++i) { auto dimension = padding_config.add_dimensions(); dimension->set_edge_padding_low(0); - dimension->set_edge_padding_high(padded_shape.dimensions(i) - + dimension->set_edge_padding_high(target.dimensions(i) - x->shape().dimensions(i)); dimension->set_interior_padding(0); } - if (!ShapeUtil::Equal(padded_shape, x->shape())) { - HloInstruction *zero = instr->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(x->shape().element_type()))); - return instr->AddInstruction( - HloInstruction::CreatePad(padded_shape, x, zero, padding_config)); - } - return x; + + HloInstruction *zero = instr->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(x->shape().element_type()))); + return instr->AddInstruction( + HloInstruction::CreatePad(target, x, zero, padding_config)); +} + +// Pad the non-batch dimensions of the operands to multiples of 16 as required +// by cuBLASLt FP8 gemms. +HloInstruction *PadOperandToMultipleOf16(absl::Span batch_dims, + HloInstruction *instr, + HloInstruction *x) { + Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims); + return PadOperandToTargetShape(padded_shape, instr, x); } // Recursively collects unary, pad, divide or multiply operands of instr until @@ -1048,8 +1057,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add // instruction in the following form: - // Add(OptionalBitcast(OptionalSlice(gemm)), bias) where 'gemm' is expected - // to be a cuBLAS custom_call. + // Add(OptionalBitcast(OptionalSlice(gemm)), bias) + // where 'gemm' is expected to be a cuBLAS custom_call. Slices are introduced + // when the inputs of the gemm are possibly padded. 
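  // Illustrative case with assumed shapes, not from this patch: when the gemm
  // inputs were padded, the gemm result (say f32[16,16]) is sliced back to
  // f32[14,14] before a f32[14,14] bias add; the fusion pads the bias up to
  // the gemm shape (here on the cuBLASLt paths) and re-applies the slice
  // after the fused custom call.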
Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, const HloInstruction *gemm, HloInstruction *bitcast = nullptr, @@ -1113,19 +1123,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); - HloInstruction *broadcast_bias = MaybeConstantFoldBias(bias); + HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias); if (slice && bitcast) { - broadcast_bias = instr->AddInstruction( - HloInstruction::CreateBitcast(slice->shape(), broadcast_bias)); + maybe_constant_folded_bias = + instr->AddInstruction(HloInstruction::CreateBitcast( + slice->shape(), maybe_constant_folded_bias)); } - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget) { - broadcast_bias = PadOperandToMultipleOf16( - config.dot_dimension_numbers().rhs_batch_dimensions(), instr, - broadcast_bias); + if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget || + gemm->custom_call_target() == kCublasLtMatmulCallTarget) { + maybe_constant_folded_bias = PadOperandToTargetShape( + gemm->shape(), instr, maybe_constant_folded_bias); } - operands.insert(operands.begin() + 2, broadcast_bias); + operands.insert(operands.begin() + 2, maybe_constant_folded_bias); std::unique_ptr fused_op = gemm->CloneWithNewOperands(gemm->shape(), operands); @@ -1154,6 +1165,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); if (slice != nullptr) { + int slice_op_dim = slice->operand(0)->shape().rank(); + if (slice->slice_starts() != std::vector(slice_op_dim, 0)) { + return OkStatus(); + } + fused_op = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(fused_op))}); @@ -1168,19 +1184,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(instr, std::move(fused_op)); } + // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add + // instruction in the following form: + // Add(OptionalBitcast(OptionalSlice(gemm)), broadcast) + // where 'gemm' is expected to be a cuBLAS custom_call. The optional + // convert is only used for F8 matmuls as cublasLt has specific constraints + // on the vector bias type for such matmuls. The optional bitcast is + // necessary to handle high rank input cases. StatusOr FuseVectorBiasAdd(HloInstruction *instr, HloInstruction *broadcast, HloInstruction *gemm, HloInstruction *slice = nullptr, HloInstruction *convert = nullptr, HloInstruction *bitcast = nullptr) { - // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add - // instruction in the following form: - // Add(OptionalBitcast(OptionalSlice(gemm)), broadcast), where 'gemm' is - // expected to be a cuBLAS custom_call. The optional convert is only - // applicable to F8 matmul as cublasLt has specific constraints on the - // vector bias type. The optional bitcast is necessary to handle high rank - // input cases. if (bitcast == nullptr) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); @@ -1209,23 +1225,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // dimensions; i.e. its most minor physical dimensions align with most minor // physical dimensions of the gemm output. 
absl::Span broadcast_dims = broadcast->dimensions(); - if (bitcast != nullptr) { - broadcast_dims = gemm->shape().dimensions(); - } else { - for (size_t i = 0; i < num_col_dims; ++i) { - int64_t dim = gemm->shape().layout().minor_to_major(i); + for (size_t i = 0; i < num_col_dims; ++i) { + int64_t dim = + (bitcast ? bitcast : gemm)->shape().layout().minor_to_major(i); - // Find the corresponding dimension from the bias vector. - auto it = absl::c_find(broadcast_dims, dim); + // Find the corresponding dimension from the bias vector. + auto it = absl::c_find(broadcast_dims, dim); - if (it == broadcast_dims.end()) { - return false; - } + if (it == broadcast_dims.end()) { + return false; + } - int64_t vector_dim = it - broadcast_dims.begin(); - if (bias->shape().layout().minor_to_major(i) != vector_dim) { - return false; - } + int64_t vector_dim = it - broadcast_dims.begin(); + if (bias->shape().layout().minor_to_major(i) != vector_dim) { + return false; } } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 4db9f60667d57b..984a0cdc57f654 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -4845,7 +4845,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[31,14]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[14,31]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:%[^ ]+]] = f8e4m3fn[] constant(0) -; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) +; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = f8e4m3fn[16,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]), padding=0_2x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[14,14]{1,0} parameter(2) ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(0) ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2 @@ -5639,6 +5639,72 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, + ScaledABUnscaledDMatrixBiasWithSliceF8) { +#if CUDA_VERSION < 12000 + GTEST_SKIP() << "A matrix bias on a matmul is only supported in CUDA 12"; +#endif + const char* hlo_text = R"( + HloModule test + ENTRY test { + x = f8e4m3fn[48,16] parameter(0) + y = f8e4m3fn[16,32] parameter(1) + b = f32[32,16] parameter(2) + x_f32 = f32[48,16] convert(x) + y_f32 = f32[16,32] convert(y) + x_scale = f32[] parameter(3) + y_scale = f32[] parameter(4) + x_scale_bcast = f32[48,16] broadcast(x_scale), dimensions={} + y_scale_bcast = f32[16,32] broadcast(y_scale), dimensions={} + x_unscaled = f32[48,16] multiply(x_f32, x_scale_bcast) + y_unscaled = f32[16,32] multiply(y_f32, y_scale_bcast) + dot_a = f32[48,32] dot(x_unscaled, y_unscaled), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_a_sliced = f32[32,16] slice(dot_a), slice={[16:48], [16:32]} + ROOT out = f32[32,16] add(dot_a_sliced, b) + } +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GemmRewriter pass( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_TRUE(changed); + + RunAndFilecheckHloRewrite(hlo_text, + GemmRewriter(se::CudaComputeCapability{ + se::CudaComputeCapability::HOPPER, 0}), + R"( +; CHECK-LABEL: ENTRY %test (x: f8e4m3fn[48,16], y: f8e4m3fn[16,32], b: f32[32,16], x_scale: f32[], 
y_scale: f32[]) -> f32[32,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fn[48,16]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f8e4m3fn[16,32]{1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = f8e4m3fn[32,16]{1,0} transpose([[P1]]), dimensions={1,0} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) +; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config="{ +; CHECK-DAG: \"alpha_real\":1 +; CHECK-DAG: \"alpha_imag\":0 +; CHECK-DAG: \"beta\":0 +; CHECK-DAG: \"dot_dimension_numbers\":{ +; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] +; CHECK-DAG: \"lhs_batch_dimensions\":[] +; CHECK-DAG: \"rhs_batch_dimensions\":[] +; CHECK-DAG: } +; CHECK-DAG: \"precision_config\":{ +; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] +; CHECK-DAG: } +; CHECK-DAG: \"epilogue\":\"DEFAULT\" +; CHECK: }" +; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]} +; CHECK-NEXT: [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]]) + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasThenVectorBiasF8) { #if CUDA_VERSION < 12000 From e6065e5c0c14869ef27809679339212bc36f7734 Mon Sep 17 00:00:00 2001 From: "guozhong.zhuang" Date: Thu, 20 Jul 2023 14:50:01 -0700 Subject: [PATCH 039/410] Use correct data type for dim size --- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 9 ++++----- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 4 ++-- .../core/kernels/mkl/mkl_layer_norm_op.cc | 6 +++--- .../core/kernels/mkl/mkl_matmul_op_fused.cc | 8 ++++---- .../core/kernels/mkl/mkl_matmul_ops_common.h | 4 ++-- tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 18 +++++++++--------- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index c179fca517df3b..8cf49990803797 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -456,9 +456,8 @@ class MklConvFwdPrimitive : public MklPrimitive { } else if (post_op_param.name == "wei_scale") { is_scale_set.insert({"wei", true}); const int scale_size = post_op_param.param.size(); - const int mask = scale_size == 1 ? 0 - : convFwdDims.is_depthwise ? 3 - : 1; + const int mask = + scale_size == 1 ? 0 : convFwdDims.is_depthwise ? 
3 : 1; post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask); context_.wei_scale_md.reset(new memory::desc( {scale_size}, MklDnnType(), memory::format_tag::x)); @@ -1786,8 +1785,8 @@ class MklFusedConvOp Eigen::Tensor bn_rsqrt = (bn_var_tensor.flat() + static_cast(epsilon)).rsqrt(); Tinput* bn_rsqrt_data = bn_rsqrt.data(); - size_t num_elem = bn_var_tensor.shape().dim_size(0); - for (size_t i = 0; i < num_elem; i++) { + int64_t num_elem = bn_var_tensor.shape().dim_size(0); + for (int64_t i = 0; i < num_elem; i++) { scale_buf_ptr[i] = bn_rsqrt_data[i]; } return; diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index c103986c198faa..58f0b140b73d4e 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -15,7 +15,6 @@ limitations under the License. #ifdef INTEL_MKL -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -23,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using namespace dnnl; using dnnl::batch_normalization_forward; @@ -80,7 +80,7 @@ class MklFusedInstanceNormOp : public OpKernel { std::shared_ptr engine_stream_ptr; engine_stream_ptr.reset(CreateStream(&eigen_tp, cpu_engine_)); - const int batch_size = src_tensor.shape().dim_size(0); + const int64_t batch_size = src_tensor.shape().dim_size(0); const int64_t elems_per_batch = src_tensor.shape().num_elements() / batch_size; diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index ae5ad08b3f4393..afd7d6c6825905 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -15,7 +15,6 @@ limitations under the License. #ifdef INTEL_MKL -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -23,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using CPUDevice = Eigen::ThreadPoolDevice; using dnnl::layer_normalization_forward; @@ -53,8 +53,8 @@ class MklLayerNormOp : public OpKernel { OP_REQUIRES(ctx, shift_tensor.dims() == 1, errors::InvalidArgument("offset must be 1D tensor", shift_tensor.shape().DebugString())); - size_t num_elements_scale = scale_tensor.dim_size(0); - size_t num_elements_shift = shift_tensor.dim_size(0); + int64_t num_elements_scale = scale_tensor.dim_size(0); + int64_t num_elements_shift = shift_tensor.dim_size(0); OP_REQUIRES( ctx, num_elements_scale == num_elements_shift, errors::InvalidArgument("Number of elements in scale and shift", diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 1d388705e7ddd0..a6dde0d357b07c 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -99,10 +99,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // Get dimension size of each matrix, dim_pair[] is the location of k // in the inputs, we have constraint that k of the two inputs are // the same - const int dim_pair[] = {1, transpose_b_ ? 1 : 0}; - const int batch = src_tf_shape.dim_size(1 - dim_pair[0]); - const int k = src_tf_shape.dim_size(dim_pair[0]); - const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]); + const int64_t dim_pair[] = {1, transpose_b_ ? 1 : 0}; + const int64_t batch = src_tf_shape.dim_size(1 - dim_pair[0]); + const int64_t k = src_tf_shape.dim_size(dim_pair[0]); + const int64_t channel = weight_tf_shape.dim_size(1 - dim_pair[1]); OP_REQUIRES( ctx, k == weight_tf_shape.dim_size(dim_pair[1]), diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 3e55f11cd24abe..b9cfe85369e628 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -21,12 +21,12 @@ limitations under the License. #include #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/onednn_env_vars.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #ifdef DNNL_AARCH64_USE_ACL #include "tensorflow/core/platform/mutex.h" #endif @@ -79,7 +79,7 @@ inline bool ExecuteSingleThreadedGemm(int64_t m, int64_t n, int64_t k, constexpr float kHeuristicMultiplier = 1.01; const float mul_size = bytes * (m * n + k * (m + n)); const float l2_heur = l2_size * kHeuristicMultiplier; - return mul_size < l2_heur; + return (!(mul_size < 0) && (mul_size < l2_heur)); } // This structure aggregates multiple inputs to MklDnnMatMul* methods. diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 259dfacc0bf51b..a757b41b590b90 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -536,8 +536,8 @@ class MklDnnQuantizedMatMulOp // compensated with B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * // Wf32. 
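    // Illustrative reading, not part of this patch: because wt_buf already
    // holds the quantized weights (roughly Qw * Wf32), the "1 * Wf32" term
    // above reduces to a per-output-channel column sum, so the loops below
    // accumulate sum_i(W[i][j]) for each column j and that sum is scaled by
    // Q'a * Min(Af32) when the compensated bias is formed.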
if (mode_ == QUANTIZE_MODE_MIN_FIRST) { - int k = weight_tensor.dim_size(0); - int n = weight_tensor.dim_size(1); + int64_t k = weight_tensor.dim_size(0); + int64_t n = weight_tensor.dim_size(1); float* comp_bias = GetCompBiasBuffer(n); qint8* wt_buf = static_cast( @@ -553,10 +553,10 @@ class MklDnnQuantizedMatMulOp std::max(std::abs(max_weight), std::abs(min_weight))); #ifndef ENABLE_ONEDNN_OPENMP - auto parallel_func = [&](int64 start, int64 end) { - for (int64 j = start; j < end; j++) { - int x = 0; - for (int64 i = 0; i < k; ++i) { + auto parallel_func = [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j++) { + int64_t x = 0; + for (int64_t i = 0; i < k; ++i) { x += wt_buf[i * n + j]; } comp_bias[j] = @@ -573,9 +573,9 @@ class MklDnnQuantizedMatMulOp parallel_func); #else #pragma omp parallel for schedule(static) - for (int j = 0; j < n; ++j) { - int x = 0; - for (int i = 0; i < k; ++i) { + for (int64_t j = 0; j < n; ++j) { + int64_t x = 0; + for (int64_t i = 0; i < k; ++i) { x += wt_buf[i * n + j]; } comp_bias[j] = From 528c1dc11d995ece9c2841adbd834cbd2afcae9e Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 21 Jul 2023 11:20:55 +0100 Subject: [PATCH 040/410] Address comments from review --- tensorflow/tsl/platform/cpu_info.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index 1f9d6905fe4e20..e06276fba5a95a 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -376,9 +376,10 @@ class CPUIDInfo { return; } + int present_cpu = -1; +#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) && !defined(__OpenBSD__) std::ifstream CPUspresent; CPUspresent.open("/sys/devices/system/cpu/present", std::ios::in); - int present_cpu = -1; if (CPUspresent.is_open()) { std::string line; if (bool(getline(CPUspresent, line))) { @@ -397,11 +398,13 @@ class CPUIDInfo { present_cpu = std::stoi(line); } } +#endif if (present_cpu == -1) { return; } +#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) && !defined(__OpenBSD__) std::stringstream str; str << "/sys/devices/system/cpu/cpu" << present_cpu << "/regs/identification/midr_el1"; @@ -417,6 +420,7 @@ class CPUIDInfo { cpuid->cpunum_ = (midr_el1 >> 4) & 0xFFF; } } +#endif } int implementer() const { return implementer_; } From 52488d3248cf2678c3d760098750c6a61c37aa36 Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Fri, 21 Jul 2023 19:53:00 +0800 Subject: [PATCH 041/410] segment_reduction_ops_gpu: Comments cleanup https://github.com/tensorflow/tensorflow/pull/61339 --- .../core/kernels/segment_reduction_ops_gpu.cu.h | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h index 4c4ff91ef4b792..b975d2ebdff15d 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h @@ -69,6 +69,8 @@ DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu) DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu) #undef DEFINE_REDUCE_UPDATE_OP_FOR +// PR#61339: MSVC does not support compound-assignment operators on device + // SortedSegmentReductionFunctor kernel reduces input data just as // UnsortedSegmentReductionCustomKernel does except that input data // is partitioned along the outer reduction dimension. 
This is @@ -347,7 +349,7 @@ __global__ void SegmentReduceVectorKernel( Treducevec block_result = x_ok && y_ok ? input_vec[input_idx] : Tvec(initial_value); // Apply weights if provided. - if (weights && y_ok) block_result = block_result * Tvec(weights[y_idx]); // MSVC fix + if (weights && y_ok) block_result = block_result * Tvec(weights[y_idx]); // Reduce along the columns of the block, returning result in first row. block_result = ReduceBlockAlongCols(reduce_op, block_result, x_ok); if (y == 0 && x_ok) { @@ -363,9 +365,9 @@ __global__ void SegmentReduceVectorKernel( typename RealTypeIfComplex::type total_weight(end - begin); // Normalize the results if necessary. if (is_mean) { - result = result / Treducevec(total_weight); // MSVC fix + result = result / Treducevec(total_weight); } else if (is_sqrtn) { - result = result / Treducevec(sqrt(static_cast(total_weight))); // MSVC fix + result = result / Treducevec(sqrt(static_cast(total_weight))); } } // Cast from Treducevec to Tvec. @@ -439,9 +441,8 @@ __global__ void SegmentReduceEpilogueKernel( // Empty segment. val = Treducevec(empty_segment_value); } else if (is_mean) { - val = val / Treducevec(segment_size); // MSVC fix + val = val / Treducevec(segment_size); } else if (is_sqrtn) { - // MSVC fix val = val / Treducevec( sqrt(static_cast(typename RealTypeIfComplex::type(segment_size)))); } @@ -492,7 +493,7 @@ struct LookupAndScaleAndCastInputsFunctor { __device__ Treducevec operator()(Tindex idx) const { if (indices_) idx = indices_[idx]; Treducevec result = static_cast(input_vec_[idx]); - if (weights_) result = result * Tvec(weights_[idx]); // MSVC fix + if (weights_) result = result * Tvec(weights_[idx]); return result; } From 8cd941b54c44d259e74992309d8c4def4eb93745 Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Fri, 21 Jul 2023 23:18:22 +0800 Subject: [PATCH 042/410] gpu_kernel_helper.h: Add uint `tf_min/max()` overloading --- tensorflow/core/util/gpu_kernel_helper.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h index ba6ae9ab153fa0..6857239488f307 100644 --- a/tensorflow/core/util/gpu_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -180,6 +180,12 @@ __host__ __device__ inline int tf_min(int x, int y) { __host__ __device__ inline int tf_max(int x, int y) { return max(x, y); } +__host__ __device__ inline int tf_min(unsigned int x, int y) { + return min(static_cast(x), y); +} +__host__ __device__ inline int tf_max(unsigned int x, int y) { + return max(static_cast(x), y); +} #endif #endif From 2cc1fcd602fa175fc25ba70f12977067aee2160d Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 21 Jul 2023 08:33:17 -0700 Subject: [PATCH 043/410] Improve based on review 2 --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 57 ++++++++++--------- .../service/gpu/tests/gemm_rewrite_test.cc | 2 + 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 9de8bdf2ec0e79..4e3b48a9dc5d27 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -114,7 +114,6 @@ Shape PadShapeToMultipleOf16(const Shape old_shape, // Pad the non-batch dimensions of the operands to the target shape. 
HloInstruction *PadOperandToTargetShape(const Shape &target, - HloInstruction *instr, HloInstruction *x) { if (ShapeUtil::Equal(target, x->shape())) { return x; @@ -129,19 +128,18 @@ HloInstruction *PadOperandToTargetShape(const Shape &target, dimension->set_interior_padding(0); } - HloInstruction *zero = instr->AddInstruction(HloInstruction::CreateConstant( + HloInstruction *zero = x->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(x->shape().element_type()))); - return instr->AddInstruction( + return x->AddInstruction( HloInstruction::CreatePad(target, x, zero, padding_config)); } // Pad the non-batch dimensions of the operands to multiples of 16 as required // by cuBLASLt FP8 gemms. HloInstruction *PadOperandToMultipleOf16(absl::Span batch_dims, - HloInstruction *instr, HloInstruction *x) { Shape padded_shape = PadShapeToMultipleOf16(x->shape(), batch_dims); - return PadOperandToTargetShape(padded_shape, instr, x); + return PadOperandToTargetShape(padded_shape, x); } // Recursively collects unary, pad, divide or multiply operands of instr until @@ -882,8 +880,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { b = TransposeMatrix(b, b_contracting_dims[0], batch_dims); } - a = PadOperandToMultipleOf16(batch_dims, instr, a); - b = PadOperandToMultipleOf16(batch_dims, instr, b); + a = PadOperandToMultipleOf16(batch_dims, a); + b = PadOperandToMultipleOf16(batch_dims, b); Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims); std::vector operands_list = { @@ -1058,8 +1056,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Fuses a matrix bias into a cuBLAS call. 'instr' should be an Add // instruction in the following form: // Add(OptionalBitcast(OptionalSlice(gemm)), bias) - // where 'gemm' is expected to be a cuBLAS custom_call. Slices are introduced - // when the inputs of the gemm are possibly padded. + // where 'gemm' is expected to be a cuBLAS custom_call. Slice is introduced + // when the inputs of the gemm are possibly padded. Bitcast is introduced to + // handle high rank input. Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias, const HloInstruction *gemm, HloInstruction *bitcast = nullptr, @@ -1074,6 +1073,15 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } + // To ensure numerical correctness, only slices that chop off the ends of + // dimensions are supported. + if (slice) { + int slice_op_dim = slice->operand(0)->shape().rank(); + if (slice->slice_starts() != std::vector(slice_op_dim, 0) || + slice->slice_strides() != std::vector(slice_op_dim, 1)) { + return OkStatus(); + } + } // Cublas gemm overwrites the bias matrix, so fusion is only possible if the // gemm is the only user. CublasLt gemm can operate out-of-place. 
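    // Note, illustrative rather than from this patch: the legacy cuBLAS gemm
    // path writes its result into the bias buffer (the fused call below
    // aliases its output with operand 2), which is why that bias must have no
    // other users; cuBLASLt receives the bias as an ordinary input and does
    // not need this restriction.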
bool can_overwrite_bias = [bias]() { @@ -1124,16 +1132,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias); - if (slice && bitcast) { + // if (slice && bitcast) { + if (bitcast) { maybe_constant_folded_bias = instr->AddInstruction(HloInstruction::CreateBitcast( slice->shape(), maybe_constant_folded_bias)); } if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget || - gemm->custom_call_target() == kCublasLtMatmulCallTarget) { + gemm->custom_call_target() == kCublasLtMatmulCallTarget || + gemm->custom_call_target() == kGemmCallTarget) { maybe_constant_folded_bias = PadOperandToTargetShape( - gemm->shape(), instr, maybe_constant_folded_bias); + gemm->shape(), maybe_constant_folded_bias); } operands.insert(operands.begin() + 2, maybe_constant_folded_bias); @@ -1164,18 +1174,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ->set_output_to_operand_aliasing({{{}, {2, {}}}}); } TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get())); - if (slice != nullptr) { - int slice_op_dim = slice->operand(0)->shape().rank(); - if (slice->slice_starts() != std::vector(slice_op_dim, 0)) { - return OkStatus(); - } - + if (slice) { fused_op = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(fused_op))}); } - if (bitcast != nullptr) { + if (bitcast) { fused_op = bitcast->CloneWithNewOperands( bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(fused_op))}); @@ -1186,7 +1191,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Fuses a vector bias into a cuBLAS call. 'instr' should be an Add // instruction in the following form: - // Add(OptionalBitcast(OptionalSlice(gemm)), broadcast) + // Add(OptionalBitcast(OptionalSlice(gemm)), Broadcast(OptionalConvert())) // where 'gemm' is expected to be a cuBLAS custom_call. The optional // convert is only used for F8 matmuls as cublasLt has specific constraints // on the vector bias type for such matmuls. The optional bitcast is @@ -1197,7 +1202,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { HloInstruction *slice = nullptr, HloInstruction *convert = nullptr, HloInstruction *bitcast = nullptr) { - if (bitcast == nullptr) { + if (!bitcast) { TF_RET_CHECK(ShapeUtil::Compatible( broadcast->shape(), (slice ? slice->shape() : gemm->shape()))); } @@ -1288,9 +1293,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // In the case of high rank input, it is necessary to consider potential // padding for the bias. - if (bitcast != nullptr && slice != nullptr) { + if (bitcast) { bias = PadOperandToMultipleOf16( - config.dot_dimension_numbers().rhs_batch_dimensions(), instr, bias); + config.dot_dimension_numbers().rhs_batch_dimensions(), bias); } // Replace add(gemm, broadcast) with fused new_gemm. 
operands.push_back(bias); @@ -1299,12 +1304,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm->CloneWithNewOperands(gemm->shape(), operands); TF_RETURN_IF_ERROR(result->set_backend_config(config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); - if (slice != nullptr) { + if (slice) { result = slice->CloneWithNewOperands( slice->shape(), {slice->parent()->AddInstruction(std::move(result))}); } - if (bitcast != nullptr) { + if (bitcast) { result = bitcast->CloneWithNewOperands( bitcast->shape(), {bitcast->parent()->AddInstruction(std::move(result))}); @@ -1341,7 +1346,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(result->set_backend_config(config)); TF_RETURN_IF_ERROR(SetName(result->GetModule(), result.get())); - if (slice_or_bitcast != nullptr) { + if (slice_or_bitcast) { result = slice_or_bitcast->CloneWithNewOperands( slice_or_bitcast->shape(), {slice_or_bitcast->parent()->AddInstruction(std::move(result))}); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 984a0cdc57f654..5fedc8a5fc2c3d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -5639,6 +5639,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, )"); } +// Do not fuse matrix bias When there is a slice that does not chop off the ends +// of dimensions. TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasWithSliceF8) { #if CUDA_VERSION < 12000 From 5cfd215b403aeeb90f48e8a7eb29120012b676ff Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 21 Jul 2023 19:03:21 +0100 Subject: [PATCH 044/410] Change to comment to start with capital letter and end with fullstop --- tensorflow/core/common_runtime/mkl_layout_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index bc21b8b019352b..e612fe32994f37 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -432,7 +432,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfothr_.push_back( {{csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConvCheckConstFilter, - std::function(), // we set this function to empty + std::function(), // We set this function to empty. GetRewriteCause()}, Conv2DRewrite}); rinfo_.push_back({csinfo_.conv2d_with_bias, From 7f67d86a9de6fe1aaed178ffa7f6e9545304e676 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 21 Jul 2023 12:53:01 -0700 Subject: [PATCH 045/410] Remove comments --- .../compiler/xla/service/gpu/gemm_rewriter.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 4e3b48a9dc5d27..26f6c4b0af9007 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -112,7 +112,7 @@ Shape PadShapeToMultipleOf16(const Shape old_shape, return padded_shape; } -// Pad the non-batch dimensions of the operands to the target shape. +// Pad the dimensions of the operands to the target shape. 
HloInstruction *PadOperandToTargetShape(const Shape &target, HloInstruction *x) { if (ShapeUtil::Equal(target, x->shape())) { @@ -1073,8 +1073,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return OkStatus(); } - // To ensure numerical correctness, only slices that chop off the ends of - // dimensions are supported. + // To ensure correctness, only slices that chop off the ends of dimensions + // are supported. if (slice) { int slice_op_dim = slice->operand(0)->shape().rank(); if (slice->slice_starts() != std::vector(slice_op_dim, 0) || @@ -1132,19 +1132,14 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { std::vector operands(gemm->operands().begin(), gemm->operands().end()); HloInstruction *maybe_constant_folded_bias = MaybeConstantFoldBias(bias); - // if (slice && bitcast) { if (bitcast) { maybe_constant_folded_bias = instr->AddInstruction(HloInstruction::CreateBitcast( slice->shape(), maybe_constant_folded_bias)); } - if (gemm->custom_call_target() == kCublasLtMatmulF8CallTarget || - gemm->custom_call_target() == kCublasLtMatmulCallTarget || - gemm->custom_call_target() == kGemmCallTarget) { - maybe_constant_folded_bias = PadOperandToTargetShape( - gemm->shape(), maybe_constant_folded_bias); - } + maybe_constant_folded_bias = PadOperandToTargetShape( + gemm->shape(), maybe_constant_folded_bias); operands.insert(operands.begin() + 2, maybe_constant_folded_bias); From 3abb8b750a628d5a115737a5b335f519ec61f868 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 21 Jul 2023 22:00:34 +0100 Subject: [PATCH 046/410] Correct ordering to include headers for heuristics test --- tensorflow/core/util/mkl_heuristics_test.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/util/mkl_heuristics_test.cc b/tensorflow/core/util/mkl_heuristics_test.cc index 5dfdd1fa24a051..b45190721d246b 100644 --- a/tensorflow/core/util/mkl_heuristics_test.cc +++ b/tensorflow/core/util/mkl_heuristics_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/core/util/mkl_heuristics.h" +#include + #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/test.h" @@ -117,4 +119,4 @@ TEST(MklHeuristicsTest, MklThresholds) { } // namespace } // namespace tensorflow -#endif // INTEL_MKL \ No newline at end of file +#endif // INTEL_MKL From a1f135efcaa05abb86ad7b1f5f91258cf9ae5e60 Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Fri, 21 Jul 2023 22:16:36 +0100 Subject: [PATCH 047/410] Fix failure to build Mkl heuristics test --- tensorflow/core/util/mkl_heuristics.h | 6 +++++- tensorflow/core/util/mkl_heuristics_test.cc | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/util/mkl_heuristics.h b/tensorflow/core/util/mkl_heuristics.h index 162dbe81331c90..518712b85c7d5f 100644 --- a/tensorflow/core/util/mkl_heuristics.h +++ b/tensorflow/core/util/mkl_heuristics.h @@ -20,12 +20,16 @@ limitations under the License. 
#define TENSORFLOW_CORE_UTIL_MKL_HEURISTICS_H_ #ifdef INTEL_MKL +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/tsl/platform/cpu_info.h" namespace tensorflow { struct RewriteThreshold { - string op; + std::string op; int cpu_family; int cpu_model_num; // The model that is used to decide whether it is worth diff --git a/tensorflow/core/util/mkl_heuristics_test.cc b/tensorflow/core/util/mkl_heuristics_test.cc index b45190721d246b..9485a5247f915e 100644 --- a/tensorflow/core/util/mkl_heuristics_test.cc +++ b/tensorflow/core/util/mkl_heuristics_test.cc @@ -18,9 +18,6 @@ limitations under the License. #include "tensorflow/core/util/mkl_heuristics.h" -#include - -#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/test.h" From 8a30aa731c21612fe098a6b620a54922578611c2 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Wed, 7 Jun 2023 20:37:48 +0000 Subject: [PATCH 048/410] Support for FP8 convolutions in XLA. --- .../xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td | 15 +- tensorflow/compiler/xla/service/gpu/BUILD | 4 +- .../xla/service/gpu/backend_configs.proto | 4 + .../xla/service/gpu/buffer_comparator.cc | 1329 +++++++++++------ .../xla/service/gpu/conv_algorithm_picker.cc | 37 +- .../xla/service/gpu/conv_algorithm_picker.h | 2 +- .../service/gpu/conv_layout_normalization.cc | 3 +- .../xla/service/gpu/convolution_thunk.cc | 4 +- .../xla/service/gpu/convolution_thunk.h | 5 +- .../compiler/xla/service/gpu/cublas_cudnn.cc | 8 + .../compiler/xla/service/gpu/cublas_cudnn.h | 3 + .../service/gpu/cudnn_fused_conv_rewriter.cc | 242 ++- .../gpu/cudnn_fused_conv_rewriter_test.cc | 188 +++ .../service/gpu/cudnn_pad_for_convolutions.cc | 1 + .../gpu/gpu_conv_padding_legalization.cc | 5 +- .../xla/service/gpu/gpu_conv_runner.cc | 75 + .../xla/service/gpu/gpu_conv_runner.h | 57 +- .../xla/service/gpu/gpu_layout_assignment.cc | 39 +- .../xla/service/gpu/ir_emitter_unnested.cc | 9 +- .../compiler/xla/service/gpu/runtime/conv.h | 2 +- .../xla/service/gpu/stream_executor_util.cc | 6 + .../compiler/xla/service/layout_assignment.cc | 4 +- .../xla/stream_executor/cuda/cuda_dnn.cc | 439 +++++- .../xla/stream_executor/cuda/cuda_dnn.h | 24 + .../compiler/xla/stream_executor/dnn.cc | 30 + tensorflow/compiler/xla/stream_executor/dnn.h | 33 + .../xla/stream_executor/lazy_op_runner.h | 26 + .../compiler/xla/stream_executor/stream.h | 19 + .../stream_executor/stream_executor_pimpl.cc | 22 + .../stream_executor/stream_executor_pimpl.h | 14 + .../mhlo_to_lhlo_with_xla.cc | 7 + tensorflow/core/kernels/conv_ops_gpu.cc | 1 + tensorflow/tsl/protobuf/dnn.proto | 1 + 33 files changed, 2155 insertions(+), 503 deletions(-) diff --git a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 8e656ea706406f..b76587b9648d96 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -49,7 +49,7 @@ class GpuConvolutionAttributes { } // Provide a custom assembly format for all LHLO_GPU convolution operations. 
-class LHLOGPU_ConvBaseOp : LHLOGPU_Op { +class LHLOGPU_ConvBaseOp traits = []> : LHLOGPU_Op { let assemblyFormat = [{ `(`operands`)` `dim_numbers` `=` custom($dimension_numbers) `,` @@ -140,6 +140,19 @@ def LHLOGPU_CudnnConvReorderFilterAndBiasOp : I64ElementsAttr:$filter_dims); } +def LHLOGPU_ConvForwardGraphOp : + LHLOGPU_ConvBaseOp<"conv_forward_graph", [AttrSizedOperandSegments]> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg, "", [MemRead]>:$binary_operands, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes<(ins + StrAttr:$serialized_graph)>.attributes); +} + //===----------------------------------------------------------------------===// // LMHLO ops representing other library functions. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13b122e93d849f..7d84276d40126c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3082,7 +3082,9 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:statusor", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) xla_cc_test( diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index fcce88b29bb82b..1236d07be4cb6d 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -57,6 +57,10 @@ message CudnnConvBackendConfig { // compatible with NVidia's IMMA instruction (sm75+). bool reordered_int8_nchw_vect = 7; } + + // Serialization of the graph described by the convolution and adjacent + // pointwise ops. + optional string serialized_graph = 8; } // Backend config for the GEMM operation running through cuBLAS. diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 3f1de6468e4200..bc36fcf8eeb19b 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -38,562 +38,977 @@ namespace gpu { static constexpr double kTolerance = 0.1f; // Comparison kernel code: compare two buffers of -// bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the relative -// error does not exceed the passed rel_error_threshold. Write the number of -// mismatches into out parameter mismatch_count. -// +// fp8/bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the +// relative error does not exceed the passed rel_error_threshold. Write the +// number of mismatches into out parameter mismatch_count. + // NaN's are considered equal, and for half's we clamp all numbers to largest // and smallest numbers representable to avoid miscomparisons due to overflows. -// + // The PTX below is compiled from the following CUDA code: -// -// #include + // #include -// +// #include +// #include + // namespace { -// + // __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input) // { // // All fp16 infinities are treated as 65505 or -65505, in order to avoid // // differences due to overflows. // return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); // } -// + // } // end anonymous namespace -// -// extern "C" { // avoid name mangling -// -// -// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, + +// extern "C" { // avoid name mangling + +// __global__ void __xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t *buffer_a, +// __nv_fp8_storage_t *buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int *mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) +// return; +// // TODO(philipphack): Replace with direct conversion to float when this +// // functionality becomes availabe. +// float elem_a = +// __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E4M3)); +// float elem_b = +// __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E4M3)); +// elem_a = __xla_buffer_comparator_canonicalize(elem_a); +// elem_b = __xla_buffer_comparator_canonicalize(elem_b); +// if (isnan(elem_a) && isnan(elem_b)) +// return; + +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); + +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// } + +// __global__ void __xla_fp8_e5m2_comparison(__nv_fp8_storage_t *buffer_a, +// __nv_fp8_storage_t *buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int *mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) +// return; +// // TODO(philipphack): Replace with direct conversion to float when this +// // functionality becomes availabe. +// float elem_a = +// __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E5M2)); +// float elem_b = +// __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E5M2)); +// elem_a = __xla_buffer_comparator_canonicalize(elem_a); +// elem_b = __xla_buffer_comparator_canonicalize(elem_b); +// if (isnan(elem_a) && isnan(elem_b)) +// return; + +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); + +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// } + +// __global__ void __xla_fp16_comparison(__half *buffer_a, __half *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; +// if (idx >= buffer_length) +// return; // float elem_a = __half2float(buffer_a[idx]); // float elem_b = __half2float(buffer_b[idx]); // elem_a = __xla_buffer_comparator_canonicalize(elem_a); // elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// +// if (isnan(elem_a) && isnan(elem_b)) +// return; + +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); + // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// -// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, + +// __global__ void __xla_fp32_comparison(float *buffer_a, float *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; +// if (idx >= buffer_length) +// return; // float elem_a = buffer_a[idx]; // float 
elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; +// if (isnan(elem_a) && isnan(elem_b)) +// return; // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) // return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) + +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// -// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, + +// __global__ void __xla_fp64_comparison(double *buffer_a, double *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; -// +// if (idx >= buffer_length) +// return; + // double elem_a = buffer_a[idx]; // double elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) return; +// if (isnan(elem_a) && isnan(elem_b)) +// return; // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) // return; -// double rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) +// double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// -// __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a, -// __nv_bfloat16* buffer_b, + +// __global__ void __xla_bf16_comparison(__nv_bfloat16 *buffer_a, +// __nv_bfloat16 *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; +// if (idx >= buffer_length) +// return; // float elem_a = __bfloat162float(buffer_a[idx]); // float elem_b = __bfloat162float(buffer_b[idx]); // elem_a = __xla_buffer_comparator_canonicalize(elem_a); // elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) return; -// -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// +// if (isnan(elem_a) && isnan(elem_b)) +// return; + +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); + // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// + // // TODO(b/191520348): The comparison below requires exact equality. 
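// Worked example, illustrative: with the relative-error test below, int8
// values a=1 and b=2 give |1-2| / (max(1,2)+1) = 1/3, which is flagged at
// the 0.1 tolerance, while a=100 and b=102 give 2/103 (about 0.019) and
// pass even though they differ; hence the request for exact equality.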
-// __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, +// __global__ void __xla_int8_comparison(int8_t *buffer_a, int8_t *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; +// if (idx >= buffer_length) +// return; // float a = buffer_a[idx]; // float b = buffer_b[idx]; // float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); // if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); +// atomicAdd(mismatch_count, 1); // } -// -// __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b, + +// __global__ void __xla_int32_comparison(int *buffer_a, int *buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int* mismatch_count) { +// int *mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) return; +// if (idx >= buffer_length) +// return; // float elem_a = static_cast(buffer_a[idx]); // float elem_b = static_cast(buffer_b[idx]); -// float rel_error = abs(elem_a - elem_b) -// / (max(abs(elem_a), abs(elem_b)) + 1); -// if (rel_error > rel_error_threshold || isnan(rel_error)) +// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + +// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } // } // end extern declaration static const char* buffer_compare_ptx = R"( // -// Generated by LLVM NVPTX Back-End +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-32415258 +// Cuda compilation tools, release 12.1, V12.1.66 +// Based on NVVM 7.0.1 // -.version 4.2 -.target sm_30 +.version 8.1 +.target sm_50 .address_size 64 -// .globl__xla_fp16_comparison + // .globl __xla_fp8_e4m3fn_comparison -.visible .entry __xla_fp16_comparison( -.param .u64 __xla_fp16_comparison_param_0, -.param .u64 __xla_fp16_comparison_param_1, -.param .f32 __xla_fp16_comparison_param_2, -.param .u64 __xla_fp16_comparison_param_3, -.param .u64 __xla_fp16_comparison_param_4 +.visible .entry __xla_fp8_e4m3fn_comparison( + .param .u64 __xla_fp8_e4m3fn_comparison_param_0, + .param .u64 __xla_fp8_e4m3fn_comparison_param_1, + .param .f32 __xla_fp8_e4m3fn_comparison_param_2, + .param .u64 __xla_fp8_e4m3fn_comparison_param_3, + .param .u64 __xla_fp8_e4m3fn_comparison_param_4 ) { -.reg .pred %p<10>; -.reg .b16 %rs<3>; -.reg .f32 %f<20>; -.reg .b32 %r<6>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB0_4; -ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 1; -add.s64 %rd10, %rd3, %rd9; -ld.global.u16 %rs1, [%rd10]; -// begin inline asm -{ cvt.f32.f16 %f6, %rs1;} - -// end inline asm -add.s64 %rd11, %rd2, %rd9; -ld.global.u16 %rs2, [%rd11]; -// begin inline asm -{ cvt.f32.f16 %f7, %rs2;} - -// end inline asm -abs.f32 %f8, %f6; -setp.gtu.f32 %p2, %f8, 0f7F800000; -min.f32 %f9, %f6, 0f477FE100; -max.f32 %f10, %f9, 0fC77FE100; -selp.f32 %f1, %f6, %f10, %p2; -abs.f32 %f11, %f7; -setp.gtu.f32 %p3, %f11, 0f7F800000; -min.f32 %f12, %f7, 0f477FE100; -max.f32 %f13, %f12, 0fC77FE100; -selp.f32 %f2, %f7, %f13, %p3; -abs.f32 %f3, 
%f1; -setp.gtu.f32 %p4, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p5, %f4, 0f7F800000; -and.pred %p6, %p4, %p5; -@%p6 bra LBB0_4; -ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; -sub.f32 %f14, %f1, %f2; -abs.f32 %f15, %f14; -max.f32 %f16, %f3, %f4; -add.f32 %f17, %f16, 0f3F800000; -div.rn.f32 %f18, %f15, %f17; -setp.gt.f32 %p7, %f18, %f5; -abs.f32 %f19, %f18; -setp.gtu.f32 %p8, %f19, 0f7F800000; -or.pred %p9, %p7, %p8; -@!%p9 bra LBB0_4; -bra.uni LBB0_3; -LBB0_3: -ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r5, [%rd1], 1; -LBB0_4: -ret; + .reg .pred %p<19>; + .reg .b16 %rs<73>; + .reg .f32 %f<30>; + .reg .b32 %r<6>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd2, [__xla_fp8_e4m3fn_comparison_param_0]; + ld.param.u64 %rd3, [__xla_fp8_e4m3fn_comparison_param_1]; + ld.param.f32 %f12, [__xla_fp8_e4m3fn_comparison_param_2]; + ld.param.u64 %rd5, [__xla_fp8_e4m3fn_comparison_param_3]; + ld.param.u64 %rd4, [__xla_fp8_e4m3fn_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB0_27; + + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd1; + ld.global.u8 %rs1, [%rd7]; + shl.b16 %rs36, %rs1, 8; + and.b16 %rs2, %rs36, -32768; + and.b16 %rs3, %rs36, 30720; + shr.u16 %rs37, %rs3, 1; + add.s16 %rs61, %rs37, 8192; + and.b16 %rs62, %rs36, 1792; + and.b16 %rs38, %rs1, 127; + setp.eq.s16 %p2, %rs38, 127; + mov.u16 %rs72, 32767; + mov.u16 %rs65, %rs72; + @%p2 bra $L__BB0_10; + + setp.eq.s16 %p3, %rs3, 0; + @%p3 bra $L__BB0_4; + + shr.u16 %rs39, %rs62, 1; + or.b16 %rs40, %rs39, %rs2; + or.b16 %rs65, %rs40, %rs61; + bra.uni $L__BB0_10; + +$L__BB0_4: + setp.eq.s16 %p4, %rs62, 0; + mov.u16 %rs63, 0; + mov.u16 %rs64, %rs63; + @%p4 bra $L__BB0_9; + + and.b16 %rs43, %rs1, 4; + setp.ne.s16 %p5, %rs43, 0; + @%p5 bra $L__BB0_8; + + mov.u16 %rs59, %rs62; + +$L__BB0_7: + shl.b16 %rs62, %rs59, 1; + add.s16 %rs61, %rs61, -1024; + and.b16 %rs44, %rs59, 512; + setp.eq.s16 %p6, %rs44, 0; + mov.u16 %rs59, %rs62; + @%p6 bra $L__BB0_7; + +$L__BB0_8: + and.b16 %rs64, %rs62, 1022; + mov.u16 %rs63, %rs61; + +$L__BB0_9: + or.b16 %rs45, %rs63, %rs2; + or.b16 %rs65, %rs45, %rs64; + +$L__BB0_10: + // begin inline asm + { cvt.f32.f16 %f27, %rs65;} + + // end inline asm + cvta.to.global.u64 %rd8, %rd3; + add.s64 %rd9, %rd8, %rd1; + ld.global.u8 %rs18, [%rd9]; + shl.b16 %rs48, %rs18, 8; + and.b16 %rs19, %rs48, -32768; + and.b16 %rs20, %rs48, 30720; + shr.u16 %rs49, %rs20, 1; + add.s16 %rs68, %rs49, 8192; + and.b16 %rs69, %rs48, 1792; + and.b16 %rs50, %rs18, 127; + setp.eq.s16 %p7, %rs50, 127; + @%p7 bra $L__BB0_19; + + setp.eq.s16 %p8, %rs20, 0; + @%p8 bra $L__BB0_13; + + shr.u16 %rs51, %rs69, 1; + or.b16 %rs52, %rs51, %rs19; + or.b16 %rs72, %rs52, %rs68; + bra.uni $L__BB0_19; + +$L__BB0_13: + setp.eq.s16 %p9, %rs69, 0; + mov.u16 %rs70, 0; + mov.u16 %rs71, %rs70; + @%p9 bra $L__BB0_18; + + and.b16 %rs55, %rs18, 4; + setp.ne.s16 %p10, %rs55, 0; + @%p10 bra $L__BB0_17; + + mov.u16 %rs66, %rs69; + +$L__BB0_16: + shl.b16 %rs69, %rs66, 1; + add.s16 %rs68, %rs68, -1024; + and.b16 %rs56, %rs66, 512; + setp.eq.s16 %p11, %rs56, 0; + mov.u16 %rs66, %rs69; + @%p11 bra $L__BB0_16; + +$L__BB0_17: + and.b16 %rs71, %rs69, 1022; + mov.u16 %rs70, %rs68; + +$L__BB0_18: + or.b16 %rs57, %rs70, %rs19; + or.b16 %rs72, %rs57, %rs71; + +$L__BB0_19: + // begin inline asm + { cvt.f32.f16 %f29, %rs72;} + + // end inline asm + abs.f32 %f15, 
%f27; + setp.gtu.f32 %p12, %f15, 0f7F800000; + @%p12 bra $L__BB0_21; + + mov.f32 %f16, 0f477FE100; + min.f32 %f17, %f27, %f16; + mov.f32 %f18, 0fC77FE100; + max.f32 %f27, %f18, %f17; + +$L__BB0_21: + abs.f32 %f28, %f29; + setp.gtu.f32 %p13, %f28, 0f7F800000; + @%p13 bra $L__BB0_23; + + mov.f32 %f19, 0f477FE100; + min.f32 %f20, %f29, %f19; + mov.f32 %f21, 0fC77FE100; + max.f32 %f29, %f21, %f20; + abs.f32 %f28, %f29; + +$L__BB0_23: + abs.f32 %f10, %f27; + setp.gtu.f32 %p14, %f10, 0f7F800000; + setp.gtu.f32 %p15, %f28, 0f7F800000; + and.pred %p16, %p14, %p15; + @%p16 bra $L__BB0_27; + + sub.f32 %f22, %f27, %f29; + abs.f32 %f23, %f22; + max.f32 %f24, %f10, %f28; + add.f32 %f25, %f24, 0f3F800000; + div.rn.f32 %f11, %f23, %f25; + setp.gt.f32 %p17, %f11, %f12; + @%p17 bra $L__BB0_26; + + abs.f32 %f26, %f11; + setp.le.f32 %p18, %f26, 0f7F800000; + @%p18 bra $L__BB0_27; + +$L__BB0_26: + cvta.to.global.u64 %rd10, %rd4; + atom.global.add.u32 %r5, [%rd10], 1; + +$L__BB0_27: + ret; } -// .globl__xla_fp32_comparison -.visible .entry __xla_fp32_comparison( -.param .u64 __xla_fp32_comparison_param_0, -.param .u64 __xla_fp32_comparison_param_1, -.param .f32 __xla_fp32_comparison_param_2, -.param .u64 __xla_fp32_comparison_param_3, -.param .u64 __xla_fp32_comparison_param_4 + // .globl __xla_fp8_e5m2_comparison +.visible .entry __xla_fp8_e5m2_comparison( + .param .u64 __xla_fp8_e5m2_comparison_param_0, + .param .u64 __xla_fp8_e5m2_comparison_param_1, + .param .f32 __xla_fp8_e5m2_comparison_param_2, + .param .u64 __xla_fp8_e5m2_comparison_param_3, + .param .u64 __xla_fp8_e5m2_comparison_param_4 ) { -.reg .pred %p<12>; -.reg .f32 %f<12>; -.reg .b32 %r<9>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB1_6; -ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 2; -add.s64 %rd10, %rd3, %rd9; -ld.global.f32 %f1, [%rd10]; -add.s64 %rd11, %rd2, %rd9; -ld.global.f32 %f2, [%rd11]; -abs.f32 %f3, %f1; -setp.gtu.f32 %p2, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p3, %f4, 0f7F800000; -and.pred %p4, %p2, %p3; -@%p4 bra LBB1_6; -setp.eq.f32 %p5, %f3, 0f7F800000; -setp.eq.f32 %p6, %f4, 0f7F800000; -and.pred %p7, %p5, %p6; -@!%p7 bra LBB1_4; -bra.uni LBB1_3; -LBB1_3: -mov.b32 %r5, %f1; -mov.b32 %r6, %f2; -xor.b32 %r7, %r6, %r5; -setp.gt.s32 %p8, %r7, -1; -@%p8 bra LBB1_6; -LBB1_4: -ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; -sub.f32 %f6, %f1, %f2; -abs.f32 %f7, %f6; -max.f32 %f8, %f3, %f4; -add.f32 %f9, %f8, 0f3F800000; -div.rn.f32 %f10, %f7, %f9; -setp.gt.f32 %p9, %f10, %f5; -abs.f32 %f11, %f10; -setp.gtu.f32 %p10, %f11, 0f7F800000; -or.pred %p11, %p9, %p10; -@!%p11 bra LBB1_6; -bra.uni LBB1_5; -LBB1_5: -ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r8, [%rd1], 1; -LBB1_6: -ret; + .reg .pred %p<11>; + .reg .b16 %rs<9>; + .reg .f32 %f<30>; + .reg .b32 %r<6>; + .reg .b64 %rd<11>; + + + ld.param.u64 %rd2, [__xla_fp8_e5m2_comparison_param_0]; + ld.param.u64 %rd3, [__xla_fp8_e5m2_comparison_param_1]; + ld.param.f32 %f12, [__xla_fp8_e5m2_comparison_param_2]; + ld.param.u64 %rd5, [__xla_fp8_e5m2_comparison_param_3]; + ld.param.u64 %rd4, [__xla_fp8_e5m2_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + 
mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB1_9; + + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd1; + ld.global.u8 %rs3, [%rd7]; + shl.b16 %rs4, %rs3, 8; + and.b16 %rs5, %rs3, 127; + setp.gt.u16 %p2, %rs5, 124; + selp.b16 %rs1, 32767, %rs4, %p2; + // begin inline asm + { cvt.f32.f16 %f27, %rs1;} + + // end inline asm + cvta.to.global.u64 %rd8, %rd3; + add.s64 %rd9, %rd8, %rd1; + ld.global.u8 %rs6, [%rd9]; + shl.b16 %rs7, %rs6, 8; + and.b16 %rs8, %rs6, 127; + setp.gt.u16 %p3, %rs8, 124; + selp.b16 %rs2, 32767, %rs7, %p3; + // begin inline asm + { cvt.f32.f16 %f29, %rs2;} + + // end inline asm + abs.f32 %f15, %f27; + setp.gtu.f32 %p4, %f15, 0f7F800000; + @%p4 bra $L__BB1_3; + + mov.f32 %f16, 0f477FE100; + min.f32 %f17, %f27, %f16; + mov.f32 %f18, 0fC77FE100; + max.f32 %f27, %f18, %f17; + +$L__BB1_3: + abs.f32 %f28, %f29; + setp.gtu.f32 %p5, %f28, 0f7F800000; + @%p5 bra $L__BB1_5; + + mov.f32 %f19, 0f477FE100; + min.f32 %f20, %f29, %f19; + mov.f32 %f21, 0fC77FE100; + max.f32 %f29, %f21, %f20; + abs.f32 %f28, %f29; + +$L__BB1_5: + abs.f32 %f10, %f27; + setp.gtu.f32 %p6, %f10, 0f7F800000; + setp.gtu.f32 %p7, %f28, 0f7F800000; + and.pred %p8, %p6, %p7; + @%p8 bra $L__BB1_9; + + sub.f32 %f22, %f27, %f29; + abs.f32 %f23, %f22; + max.f32 %f24, %f10, %f28; + add.f32 %f25, %f24, 0f3F800000; + div.rn.f32 %f11, %f23, %f25; + setp.gt.f32 %p9, %f11, %f12; + @%p9 bra $L__BB1_8; + + abs.f32 %f26, %f11; + setp.le.f32 %p10, %f26, 0f7F800000; + @%p10 bra $L__BB1_9; + +$L__BB1_8: + cvta.to.global.u64 %rd10, %rd4; + atom.global.add.u32 %r5, [%rd10], 1; + +$L__BB1_9: + ret; } -// .globl__xla_fp64_comparison -.visible .entry __xla_fp64_comparison( -.param .u64 __xla_fp64_comparison_param_0, -.param .u64 __xla_fp64_comparison_param_1, -.param .f32 __xla_fp64_comparison_param_2, -.param .u64 __xla_fp64_comparison_param_3, -.param .u64 __xla_fp64_comparison_param_4 + // .globl __xla_fp16_comparison +.visible .entry __xla_fp16_comparison( + .param .u64 __xla_fp16_comparison_param_0, + .param .u64 __xla_fp16_comparison_param_1, + .param .f32 __xla_fp16_comparison_param_2, + .param .u64 __xla_fp16_comparison_param_3, + .param .u64 __xla_fp16_comparison_param_4 ) { -.reg .pred %p<16>; -.reg .f32 %f<2>; -.reg .b32 %r<13>; -.reg .f64 %fd<12>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; -mov.u32 %r2, %tid.x; -mov.u32 %r3, %ctaid.x; -mov.u32 %r4, %ntid.x; -mad.lo.s32 %r5, %r4, %r3, %r2; -cvt.s64.s32 %rd4, %r5; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB2_6; -ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; -ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 3; -add.s64 %rd10, %rd3, %rd9; -ld.global.f64 %fd1, [%rd10]; -add.s64 %rd11, %rd2, %rd9; -ld.global.f64 %fd2, [%rd11]; -abs.f64 %fd3, %fd1; -setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; -abs.f64 %fd4, %fd2; -setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; -and.pred %p4, %p2, %p3; -@%p4 bra LBB2_6; -{ -.reg .b32 %temp; -mov.b64 {%r6, %temp}, %fd1; -} -{ -.reg .b32 %temp; -mov.b64 {%temp, %r1}, %fd1; + .reg .pred %p<9>; + .reg .b16 %rs<3>; + .reg .f32 %f<30>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_fp16_comparison_param_0]; + ld.param.u64 %rd3, [__xla_fp16_comparison_param_1]; + ld.param.f32 %f12, [__xla_fp16_comparison_param_2]; + ld.param.u64 %rd5, [__xla_fp16_comparison_param_3]; + ld.param.u64 %rd4, [__xla_fp16_comparison_param_4]; 
+ mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB2_9; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 1; + add.s64 %rd8, %rd6, %rd7; + ld.global.u16 %rs1, [%rd8]; + // begin inline asm + { cvt.f32.f16 %f27, %rs1;} + + // end inline asm + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.u16 %rs2, [%rd10]; + // begin inline asm + { cvt.f32.f16 %f29, %rs2;} + + // end inline asm + abs.f32 %f15, %f27; + setp.gtu.f32 %p2, %f15, 0f7F800000; + @%p2 bra $L__BB2_3; + + mov.f32 %f16, 0f477FE100; + min.f32 %f17, %f27, %f16; + mov.f32 %f18, 0fC77FE100; + max.f32 %f27, %f18, %f17; + +$L__BB2_3: + abs.f32 %f28, %f29; + setp.gtu.f32 %p3, %f28, 0f7F800000; + @%p3 bra $L__BB2_5; + + mov.f32 %f19, 0f477FE100; + min.f32 %f20, %f29, %f19; + mov.f32 %f21, 0fC77FE100; + max.f32 %f29, %f21, %f20; + abs.f32 %f28, %f29; + +$L__BB2_5: + abs.f32 %f10, %f27; + setp.gtu.f32 %p4, %f10, 0f7F800000; + setp.gtu.f32 %p5, %f28, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra $L__BB2_9; + + sub.f32 %f22, %f27, %f29; + abs.f32 %f23, %f22; + max.f32 %f24, %f10, %f28; + add.f32 %f25, %f24, 0f3F800000; + div.rn.f32 %f11, %f23, %f25; + setp.gt.f32 %p7, %f11, %f12; + @%p7 bra $L__BB2_8; + + abs.f32 %f26, %f11; + setp.le.f32 %p8, %f26, 0f7F800000; + @%p8 bra $L__BB2_9; + +$L__BB2_8: + cvta.to.global.u64 %rd11, %rd4; + atom.global.add.u32 %r5, [%rd11], 1; + +$L__BB2_9: + ret; + } -and.b32 %r7, %r1, 2147483647; -setp.eq.s32 %p5, %r7, 2146435072; -setp.eq.s32 %p6, %r6, 0; -and.pred %p7, %p5, %p6; -@!%p7 bra LBB2_4; -bra.uni LBB2_3; -LBB2_3: + // .globl __xla_fp32_comparison +.visible .entry __xla_fp32_comparison( + .param .u64 __xla_fp32_comparison_param_0, + .param .u64 __xla_fp32_comparison_param_1, + .param .f32 __xla_fp32_comparison_param_2, + .param .u64 __xla_fp32_comparison_param_3, + .param .u64 __xla_fp32_comparison_param_4 +) { -.reg .b32 %temp; -mov.b64 {%r8, %temp}, %fd2; + .reg .pred %p<10>; + .reg .b16 %rs<3>; + .reg .f32 %f<18>; + .reg .b32 %r<10>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_fp32_comparison_param_0]; + ld.param.u64 %rd3, [__xla_fp32_comparison_param_1]; + ld.param.f32 %f9, [__xla_fp32_comparison_param_2]; + ld.param.u64 %rd5, [__xla_fp32_comparison_param_3]; + ld.param.u64 %rd4, [__xla_fp32_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB3_9; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 2; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.f32 %f1, [%rd10]; + ld.global.f32 %f2, [%rd8]; + abs.f32 %f3, %f2; + abs.f32 %f17, %f1; + setp.gtu.f32 %p2, %f3, 0f7F800000; + @%p2 bra $L__BB3_3; + bra.uni $L__BB3_4; + +$L__BB3_3: + setp.gtu.f32 %p3, %f17, 0f7F800000; + @%p3 bra $L__BB3_9; + +$L__BB3_4: + setp.neu.f32 %p4, %f17, 0f7F800000; + setp.neu.f32 %p5, %f3, 0f7F800000; + or.pred %p6, %p5, %p4; + @%p6 bra $L__BB3_6; + + mov.b32 %r5, %f2; + shr.u32 %r6, %r5, 31; + cvt.u16.u32 %rs1, %r6; + mov.b32 %r7, %f1; + shr.u32 %r8, %r7, 31; + cvt.u16.u32 %rs2, %r8; + setp.eq.s16 %p7, %rs1, %rs2; + mov.f32 %f17, 0f7F800000; + @%p7 bra $L__BB3_9; + +$L__BB3_6: + sub.f32 %f11, %f2, %f1; + abs.f32 %f12, %f11; + max.f32 %f13, %f3, %f17; + add.f32 %f14, %f13, 0f3F800000; + div.rn.f32 %f8, %f12, %f14; + setp.gt.f32 %p8, %f8, %f9; + @%p8 bra 
$L__BB3_8; + + abs.f32 %f15, %f8; + setp.le.f32 %p9, %f15, 0f7F800000; + @%p9 bra $L__BB3_9; + +$L__BB3_8: + cvta.to.global.u64 %rd11, %rd4; + atom.global.add.u32 %r9, [%rd11], 1; + +$L__BB3_9: + ret; + } + // .globl __xla_fp64_comparison +.visible .entry __xla_fp64_comparison( + .param .u64 __xla_fp64_comparison_param_0, + .param .u64 __xla_fp64_comparison_param_1, + .param .f32 __xla_fp64_comparison_param_2, + .param .u64 __xla_fp64_comparison_param_3, + .param .u64 __xla_fp64_comparison_param_4 +) { -.reg .b32 %temp; -mov.b64 {%temp, %r9}, %fd2; -} -and.b32 %r10, %r9, 2147483647; -setp.eq.s32 %p8, %r10, 2146435072; -setp.eq.s32 %p9, %r8, 0; -and.pred %p10, %p8, %p9; -xor.b32 %r11, %r9, %r1; -setp.gt.s32 %p11, %r11, -1; -and.pred %p12, %p10, %p11; -@%p12 bra LBB2_6; -LBB2_4: -ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; -sub.f64 %fd5, %fd1, %fd2; -abs.f64 %fd6, %fd5; -max.f64 %fd7, %fd3, %fd4; -add.f64 %fd8, %fd7, 0d3FF0000000000000; -div.rn.f64 %fd9, %fd6, %fd8; -cvt.f64.f32 %fd10, %f1; -setp.gt.f64 %p13, %fd9, %fd10; -abs.f64 %fd11, %fd9; -setp.gtu.f64 %p14, %fd11, 0d7FF0000000000000; -or.pred %p15, %p13, %p14; -@!%p15 bra LBB2_6; -bra.uni LBB2_5; -LBB2_5: -ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r12, [%rd1], 1; -LBB2_6: -ret; + .reg .pred %p<13>; + .reg .b16 %rs<3>; + .reg .f32 %f<2>; + .reg .b32 %r<14>; + .reg .f64 %fd<13>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_fp64_comparison_param_0]; + ld.param.u64 %rd3, [__xla_fp64_comparison_param_1]; + ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; + ld.param.u64 %rd5, [__xla_fp64_comparison_param_3]; + ld.param.u64 %rd4, [__xla_fp64_comparison_param_4]; + mov.u32 %r3, %ntid.x; + mov.u32 %r4, %ctaid.x; + mov.u32 %r5, %tid.x; + mad.lo.s32 %r6, %r4, %r3, %r5; + cvt.s64.s32 %rd1, %r6; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB4_9; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 3; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.f64 %fd1, [%rd10]; + ld.global.f64 %fd2, [%rd8]; + abs.f64 %fd3, %fd2; + setp.le.f64 %p2, %fd3, 0d7FF0000000000000; + @%p2 bra $L__BB4_3; + + abs.f64 %fd5, %fd1; + setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000; + @%p3 bra $L__BB4_9; + +$L__BB4_3: + { + .reg .b32 %temp; + mov.b64 {%r7, %temp}, %fd2; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r1}, %fd2; + } + and.b32 %r8, %r1, 2147483647; + setp.ne.s32 %p4, %r8, 2146435072; + setp.ne.s32 %p5, %r7, 0; + or.pred %p6, %p4, %p5; + @%p6 bra $L__BB4_6; + + { + .reg .b32 %temp; + mov.b64 {%r9, %temp}, %fd1; + } + { + .reg .b32 %temp; + mov.b64 {%temp, %r2}, %fd1; + } + and.b32 %r10, %r2, 2147483647; + setp.ne.s32 %p7, %r10, 2146435072; + setp.ne.s32 %p8, %r9, 0; + or.pred %p9, %p7, %p8; + @%p9 bra $L__BB4_6; + + shr.u32 %r11, %r1, 31; + cvt.u16.u32 %rs1, %r11; + shr.u32 %r12, %r2, 31; + cvt.u16.u32 %rs2, %r12; + setp.eq.s16 %p10, %rs1, %rs2; + @%p10 bra $L__BB4_9; + +$L__BB4_6: + sub.f64 %fd6, %fd2, %fd1; + abs.f64 %fd7, %fd6; + abs.f64 %fd8, %fd1; + max.f64 %fd9, %fd3, %fd8; + add.f64 %fd10, %fd9, 0d3FF0000000000000; + div.rn.f64 %fd4, %fd7, %fd10; + cvt.f64.f32 %fd11, %f1; + setp.gt.f64 %p11, %fd4, %fd11; + @%p11 bra $L__BB4_8; + + abs.f64 %fd12, %fd4; + setp.le.f64 %p12, %fd12, 0d7FF0000000000000; + @%p12 bra $L__BB4_9; + +$L__BB4_8: + cvta.to.global.u64 %rd11, %rd4; + atom.global.add.u32 %r13, [%rd11], 1; + +$L__BB4_9: + ret; } -// .globl__xla_bf16_comparison + // .globl __xla_bf16_comparison .visible .entry 
__xla_bf16_comparison( -.param .u64 __xla_bf16_comparison_param_0, -.param .u64 __xla_bf16_comparison_param_1, -.param .f32 __xla_bf16_comparison_param_2, -.param .u64 __xla_bf16_comparison_param_3, -.param .u64 __xla_bf16_comparison_param_4 + .param .u64 __xla_bf16_comparison_param_0, + .param .u64 __xla_bf16_comparison_param_1, + .param .f32 __xla_bf16_comparison_param_2, + .param .u64 __xla_bf16_comparison_param_3, + .param .u64 __xla_bf16_comparison_param_4 ) { -.reg .pred %p<10>; -.reg .b16 %rs<3>; -.reg .f32 %f<20>; -.reg .b32 %r<6>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_bf16_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB3_4; -ld.param.u64 %rd5, [__xla_bf16_comparison_param_0]; -ld.param.u64 %rd7, [__xla_bf16_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 1; -add.s64 %rd10, %rd3, %rd9; -ld.global.u16 %rs1, [%rd10]; -// begin inline asm -{ mov.b32 %f6, {0,%rs1};} - -// end inline asm -add.s64 %rd11, %rd2, %rd9; -ld.global.u16 %rs2, [%rd11]; -// begin inline asm -{ mov.b32 %f7, {0,%rs2};} - -// end inline asm -abs.f32 %f8, %f6; -setp.gtu.f32 %p2, %f8, 0f7F800000; -min.f32 %f9, %f6, 0f477FE100; -max.f32 %f10, %f9, 0fC77FE100; -selp.f32 %f1, %f6, %f10, %p2; -abs.f32 %f11, %f7; -setp.gtu.f32 %p3, %f11, 0f7F800000; -min.f32 %f12, %f7, 0f477FE100; -max.f32 %f13, %f12, 0fC77FE100; -selp.f32 %f2, %f7, %f13, %p3; -abs.f32 %f3, %f1; -setp.gtu.f32 %p4, %f3, 0f7F800000; -abs.f32 %f4, %f2; -setp.gtu.f32 %p5, %f4, 0f7F800000; -and.pred %p6, %p4, %p5; -@%p6 bra LBB3_4; -ld.param.f32 %f5, [__xla_bf16_comparison_param_2]; -sub.f32 %f14, %f1, %f2; -abs.f32 %f15, %f14; -max.f32 %f16, %f3, %f4; -add.f32 %f17, %f16, 0f3F800000; -div.rn.f32 %f18, %f15, %f17; -setp.gt.f32 %p7, %f18, %f5; -abs.f32 %f19, %f18; -setp.gtu.f32 %p8, %f19, 0f7F800000; -or.pred %p9, %p7, %p8; -@!%p9 bra LBB3_4; -bra.uni LBB3_3; -LBB3_3: -ld.param.u64 %rd6, [__xla_bf16_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r5, [%rd1], 1; -LBB3_4: -ret; + .reg .pred %p<9>; + .reg .b16 %rs<3>; + .reg .f32 %f<30>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_bf16_comparison_param_0]; + ld.param.u64 %rd3, [__xla_bf16_comparison_param_1]; + ld.param.f32 %f12, [__xla_bf16_comparison_param_2]; + ld.param.u64 %rd5, [__xla_bf16_comparison_param_3]; + ld.param.u64 %rd4, [__xla_bf16_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB5_9; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 1; + add.s64 %rd8, %rd6, %rd7; + ld.global.u16 %rs1, [%rd8]; + // begin inline asm + { mov.b32 %f27, {0,%rs1};} + + // end inline asm + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.u16 %rs2, [%rd10]; + // begin inline asm + { mov.b32 %f29, {0,%rs2};} + + // end inline asm + abs.f32 %f15, %f27; + setp.gtu.f32 %p2, %f15, 0f7F800000; + @%p2 bra $L__BB5_3; + + mov.f32 %f16, 0f477FE100; + min.f32 %f17, %f27, %f16; + mov.f32 %f18, 0fC77FE100; + max.f32 %f27, %f18, %f17; + +$L__BB5_3: + abs.f32 %f28, %f29; + setp.gtu.f32 %p3, %f28, 0f7F800000; + @%p3 bra $L__BB5_5; + + mov.f32 %f19, 0f477FE100; + min.f32 %f20, %f29, %f19; + mov.f32 %f21, 0fC77FE100; + max.f32 %f29, %f21, %f20; + abs.f32 %f28, %f29; + +$L__BB5_5: + abs.f32 %f10, %f27; 
+ setp.gtu.f32 %p4, %f10, 0f7F800000; + setp.gtu.f32 %p5, %f28, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra $L__BB5_9; + + sub.f32 %f22, %f27, %f29; + abs.f32 %f23, %f22; + max.f32 %f24, %f10, %f28; + add.f32 %f25, %f24, 0f3F800000; + div.rn.f32 %f11, %f23, %f25; + setp.gt.f32 %p7, %f11, %f12; + @%p7 bra $L__BB5_8; + + abs.f32 %f26, %f11; + setp.le.f32 %p8, %f26, 0f7F800000; + @%p8 bra $L__BB5_9; + +$L__BB5_8: + cvta.to.global.u64 %rd11, %rd4; + atom.global.add.u32 %r5, [%rd11], 1; + +$L__BB5_9: + ret; } -// .globl__xla_int8_comparison + // .globl __xla_int8_comparison .visible .entry __xla_int8_comparison( -.param .u64 __xla_int8_comparison_param_0, -.param .u64 __xla_int8_comparison_param_1, -.param .f32 __xla_int8_comparison_param_2, -.param .u64 __xla_int8_comparison_param_3, -.param .u64 __xla_int8_comparison_param_4 + .param .u64 __xla_int8_comparison_param_0, + .param .u64 __xla_int8_comparison_param_1, + .param .f32 __xla_int8_comparison_param_2, + .param .u64 __xla_int8_comparison_param_3, + .param .u64 __xla_int8_comparison_param_4 ) { - .reg .pred %p<5>; - .reg .f32 %f<12>; - .reg .b32 %r<8>; - .reg .b64 %rd<11>; - - ld.param.u64 %rd8, [__xla_int8_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB7_3; - ld.param.f32 %f1, [__xla_int8_comparison_param_2]; - ld.param.u64 %rd5, [__xla_int8_comparison_param_0]; - ld.param.u64 %rd7, [__xla_int8_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - add.s64 %rd9, %rd3, %rd4; - ld.global.s8 %r5, [%rd9]; - add.s64 %rd10, %rd2, %rd4; - ld.global.s8 %r6, [%rd10]; - cvt.rn.f32.s32 %f2, %r5; - cvt.rn.f32.s32 %f3, %r6; - sub.f32 %f4, %f2, %f3; - abs.f32 %f5, %f4; - abs.f32 %f6, %f2; - abs.f32 %f7, %f3; - max.f32 %f8, %f6, %f7; - add.f32 %f9, %f8, 0f3F800000; - div.rn.f32 %f10, %f5, %f9; - setp.leu.f32 %p2, %f10, %f1; - abs.f32 %f11, %f10; - setp.le.f32 %p3, %f11, 0f7F800000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB7_3; - ld.param.u64 %rd6, [__xla_int8_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r7, [%rd1], 1; -LBB7_3: - ret; -} + .reg .pred %p<4>; + .reg .b16 %rs<3>; + .reg .f32 %f<12>; + .reg .b32 %r<6>; + .reg .b64 %rd<11>; + -// .globl__xla_int32_comparison + ld.param.u64 %rd2, [__xla_int8_comparison_param_0]; + ld.param.u64 %rd3, [__xla_int8_comparison_param_1]; + ld.param.f32 %f2, [__xla_int8_comparison_param_2]; + ld.param.u64 %rd5, [__xla_int8_comparison_param_3]; + ld.param.u64 %rd4, [__xla_int8_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB6_4; + + cvta.to.global.u64 %rd6, %rd2; + add.s64 %rd7, %rd6, %rd1; + ld.global.s8 %rs1, [%rd7]; + cvt.rn.f32.s16 %f3, %rs1; + cvta.to.global.u64 %rd8, %rd3; + add.s64 %rd9, %rd8, %rd1; + ld.global.s8 %rs2, [%rd9]; + cvt.rn.f32.s16 %f4, %rs2; + sub.f32 %f5, %f3, %f4; + abs.f32 %f6, %f5; + abs.f32 %f7, %f3; + abs.f32 %f8, %f4; + max.f32 %f9, %f7, %f8; + add.f32 %f10, %f9, 0f3F800000; + div.rn.f32 %f1, %f6, %f10; + setp.gt.f32 %p2, %f1, %f2; + @%p2 bra $L__BB6_3; + + abs.f32 %f11, %f1; + setp.le.f32 %p3, %f11, 0f7F800000; + @%p3 bra $L__BB6_4; + +$L__BB6_3: + cvta.to.global.u64 %rd10, %rd4; + atom.global.add.u32 %r5, [%rd10], 1; + +$L__BB6_4: + ret; + +} + // .globl __xla_int32_comparison .visible .entry __xla_int32_comparison( -.param 
.u64 __xla_int32_comparison_param_0, -.param .u64 __xla_int32_comparison_param_1, -.param .f32 __xla_int32_comparison_param_2, -.param .u64 __xla_int32_comparison_param_3, -.param .u64 __xla_int32_comparison_param_4 + .param .u64 __xla_int32_comparison_param_0, + .param .u64 __xla_int32_comparison_param_1, + .param .f32 __xla_int32_comparison_param_2, + .param .u64 __xla_int32_comparison_param_3, + .param .u64 __xla_int32_comparison_param_4 ) { -.reg .pred %p<5>; -.reg .f32 %f<12>; -.reg .b32 %r<8>; -.reg .b64 %rd<12>; - -ld.param.u64 %rd8, [__xla_int32_comparison_param_3]; -mov.u32 %r1, %tid.x; -mov.u32 %r2, %ctaid.x; -mov.u32 %r3, %ntid.x; -mad.lo.s32 %r4, %r3, %r2, %r1; -cvt.s64.s32 %rd4, %r4; -setp.ge.u64 %p1, %rd4, %rd8; -@%p1 bra LBB5_3; -ld.param.f32 %f1, [__xla_int32_comparison_param_2]; -ld.param.u64 %rd5, [__xla_int32_comparison_param_0]; -ld.param.u64 %rd7, [__xla_int32_comparison_param_1]; -cvta.to.global.u64 %rd2, %rd7; -cvta.to.global.u64 %rd3, %rd5; -shl.b64 %rd9, %rd4, 2; -add.s64 %rd10, %rd3, %rd9; -ld.global.u32 %r5, [%rd10]; -cvt.rn.f32.s32 %f2, %r5; -add.s64 %rd11, %rd2, %rd9; -ld.global.u32 %r6, [%rd11]; -cvt.rn.f32.s32 %f3, %r6; -sub.f32 %f4, %f2, %f3; -abs.f32 %f5, %f4; -abs.f32 %f6, %f2; -abs.f32 %f7, %f3; -max.f32 %f8, %f6, %f7; -add.f32 %f9, %f8, 0f3F800000; -div.rn.f32 %f10, %f5, %f9; -setp.gt.f32 %p2, %f10, %f1; -abs.f32 %f11, %f10; -setp.gtu.f32 %p3, %f11, 0f7F800000; -or.pred %p4, %p2, %p3; -@!%p4 bra LBB5_3; -bra.uni LBB5_2; -LBB5_2: -ld.param.u64 %rd6, [__xla_int32_comparison_param_4]; -cvta.to.global.u64 %rd1, %rd6; -atom.global.add.u32 %r7, [%rd1], 1; -LBB5_3: -ret; + .reg .pred %p<4>; + .reg .f32 %f<12>; + .reg .b32 %r<8>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_int32_comparison_param_0]; + ld.param.u64 %rd3, [__xla_int32_comparison_param_1]; + ld.param.f32 %f2, [__xla_int32_comparison_param_2]; + ld.param.u64 %rd5, [__xla_int32_comparison_param_3]; + ld.param.u64 %rd4, [__xla_int32_comparison_param_4]; + mov.u32 %r1, %ntid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %tid.x; + mad.lo.s32 %r4, %r2, %r1, %r3; + cvt.s64.s32 %rd1, %r4; + setp.ge.u64 %p1, %rd1, %rd5; + @%p1 bra $L__BB7_4; + + cvta.to.global.u64 %rd6, %rd2; + shl.b64 %rd7, %rd1, 2; + add.s64 %rd8, %rd6, %rd7; + ld.global.u32 %r5, [%rd8]; + cvt.rn.f32.s32 %f3, %r5; + cvta.to.global.u64 %rd9, %rd3; + add.s64 %rd10, %rd9, %rd7; + ld.global.u32 %r6, [%rd10]; + cvt.rn.f32.s32 %f4, %r6; + sub.f32 %f5, %f3, %f4; + abs.f32 %f6, %f5; + abs.f32 %f7, %f3; + abs.f32 %f8, %f4; + max.f32 %f9, %f7, %f8; + add.f32 %f10, %f9, 0f3F800000; + div.rn.f32 %f1, %f6, %f10; + setp.gt.f32 %p2, %f1, %f2; + @%p2 bra $L__BB7_3; + + abs.f32 %f11, %f1; + setp.le.f32 %p3, %f11, 0f7F800000; + @%p3 bra $L__BB7_4; + +$L__BB7_3: + cvta.to.global.u64 %rd11, %rd4; + atom.global.add.u32 %r7, [%rd11], 1; + +$L__BB7_4: + ret; } )"; @@ -773,6 +1188,12 @@ StatusOr BufferComparator::CompareEqual( se::Stream* stream, se::DeviceMemoryBase current, se::DeviceMemoryBase expected) const { switch (shape_.element_type()) { + case xla::F8E4M3FN: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_fp8_e4m3fn_comparison"); + case xla::F8E5M2: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_fp8_e5m2_comparison"); case xla::F16: return CompareEqualParameterized( stream, current, expected, shape_, config_, "__xla_fp16_comparison"); diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index 
46e87ec03531e8..c3b21a5e216e5c 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -109,7 +109,7 @@ StatusOr> ScratchAllocator::AllocateBytes( return se::DeviceMemory(buffer_addr); } -StatusOr> GetAlgorithms( +StatusOr> GetAlgorithms( const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend, bool use_fallback, const se::NumericOptions& numeric_options) { TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, @@ -122,8 +122,7 @@ StatusOr> GetAlgorithms( GetDNNDataTypeFromPrimitiveType(config.output_type)); se::StreamExecutor* stream_exec = stream->parent(); - - std::vector result; + std::vector result; switch (kind) { default: @@ -156,6 +155,30 @@ StatusOr> GetAlgorithms( break; } + case se::dnn::ConvolutionKind::FORWARD_GRAPH: { + std::vector> runners; + // This path is cuDNN-only, where the DeviceMemoryBase arguments and the + // allocator are unused; so, they're all provided as nullptr. + TF_RETURN_IF_ERROR(stream_exec->GetGraphConvolveRunners( + use_cudnn_frontend, kind, input_type, output_type, stream, + config.input_descriptor, + /* input_data = */ DeviceMemoryBase(nullptr), + config.filter_descriptor, + /* filter_data = */ DeviceMemoryBase(nullptr), + config.output_descriptor, + /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc, + use_fallback, nullptr, se::NumericOptions{deterministic_ops}, + &runners, config.serialized_graph)); + for (auto& runner : runners) { + TF_ASSIGN_OR_RETURN( + auto runner_cache, + se::dnn::LazyOpRunner::FromOpRunner( + std::move(runner))); + result.emplace_back(std::move(runner_cache)); + } + break; + } + case se::dnn::ConvolutionKind::FORWARD: case se::dnn::ConvolutionKind::BACKWARD_DATA: case se::dnn::ConvolutionKind::BACKWARD_FILTER: { @@ -446,7 +469,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( // simply skips the engine/algorithm while recording a reason for skipping it. StatusOr GpuConvAlgorithmPicker::AutotuneOneConvRunner( se::DeviceMemoryAllocator* allocator, se::Stream* stream, - MaybeFusedConvRunner* const runner, + GenericConvRunner* const runner, std::optional* reference_result, absl::Span disabled_algos, std::optional instruction_info, @@ -736,7 +759,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( std::optional reference_result; TF_ASSIGN_OR_RETURN( - std::vector runners, + std::vector runners, GetAlgorithms(runtime_arguments.gpu_conv_config, stream, cudnn_frontend_enabled, /* use_fallback = */ false, numeric_options)); @@ -760,7 +783,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( } TF_ASSIGN_OR_RETURN( - std::vector fallback_runners, + std::vector fallback_runners, GetAlgorithms(runtime_arguments.gpu_conv_config, stream, cudnn_frontend_enabled, /* use_fallback = */ true, numeric_options)); @@ -940,7 +963,7 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( se::dnn::LazyOpRunner::FromOpRunner( std::move(runner))); - MaybeFusedConvRunner runner_cache(std::move(lazy_runner)); + GenericConvRunner runner_cache(std::move(lazy_runner)); // Use assignment instead of brace-list to make GCC 4.9 happy. 
RunConvOptions options; diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h index 0d265a2e7f680b..e451c0d87b0aa6 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.h @@ -139,7 +139,7 @@ class GpuConvAlgorithmPicker : public HloModulePass { StatusOr AutotuneOneConvRunner( se::DeviceMemoryAllocator* allocator, se::Stream* stream, - MaybeFusedConvRunner* runner, + GenericConvRunner* runner, std::optional* reference_result, absl::Span disabled_algos, std::optional instruction_info, diff --git a/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc b/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc index 90b3ef1344797b..23684177e4fb8e 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc @@ -60,7 +60,8 @@ StatusOr UpdateLayoutForCudnnConvolution( gpu::GetCudnnConvKind(Cast(hlo))); switch (conv_kind) { case gpu::CudnnConvKind::kForward: - case gpu::CudnnConvKind::kForwardActivation: { + case gpu::CudnnConvKind::kForwardActivation: + case gpu::CudnnConvKind::kForwardGraph: { input_shape = lhs->shape(); filter_shape = rhs->shape(); output_shape = conv_output_shape; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 9589cfdb8b438a..c8f361ed5b886c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -41,13 +41,13 @@ ConvolutionThunk::ConvolutionThunk( scratch_buffer_(scratch_slice), config_(std::move(config)) {} -MaybeFusedConvRunner& ConvolutionThunk::GetOrCreateRunner( +GenericConvRunner& ConvolutionThunk::GetOrCreateRunner( const stream_executor::Stream* stream) { absl::MutexLock lock(&mu_); auto it = runner_cache_.find(stream); if (it == runner_cache_.end()) { it = runner_cache_ - .insert({stream, std::make_unique(config_)}) + .insert({stream, std::make_unique(config_)}) .first; } return *it->second; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index bb659f52d77471..cdd4fabb717aa4 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -57,14 +57,13 @@ class ConvolutionThunk : public Thunk { std::vector operand_buffers_; BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; - MaybeFusedConvRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); + GenericConvRunner& GetOrCreateRunner(const stream_executor::Stream* stream); // Convolution config const GpuConvConfig config_; absl::Mutex mu_; absl::flat_hash_map> + std::unique_ptr> runner_cache_ ABSL_GUARDED_BY(mu_); }; diff --git a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc index e9c9fda6d823ad..8f448b32a90fe9 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc +++ b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc @@ -53,6 +53,8 @@ const absl::string_view kCudnnConvBackwardFilterCallTarget = const absl::string_view kCudnnConvBiasActivationForwardCallTarget = "__cudnn$convBiasActivationForward"; const absl::string_view kCudnnConvForwardCallTarget = "__cudnn$convForward"; +const absl::string_view 
kCudnnConvForwardGraphCallTarget = + "__cudnn$convForwardGraph"; const absl::string_view kCudnnConvReorderFilterCallTarget = "__cudnn$convReorderFilter"; const absl::string_view kCudnnConvReorderFilterAndBiasCallTarget = @@ -103,6 +105,7 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { } const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || + target == kCudnnConvForwardGraphCallTarget || target == kCudnnConvBackwardInputCallTarget || target == kCudnnConvBackwardFilterCallTarget || target == kCudnnConvBiasActivationForwardCallTarget; @@ -170,6 +173,9 @@ StatusOr GetCudnnConvKind( if (target == kCudnnConvForwardCallTarget) { return CudnnConvKind::kForward; } + if (target == kCudnnConvForwardGraphCallTarget) { + return CudnnConvKind::kForwardGraph; + } if (target == kCudnnConvBackwardInputCallTarget) { return CudnnConvKind::kBackwardInput; } @@ -192,6 +198,8 @@ std::string CudnnConvKindToString(CudnnConvKind kind) { return "backward_input"; case CudnnConvKind::kForwardActivation: return "forward with activation"; + case CudnnConvKind::kForwardGraph: + return "forward with pointwise operations"; } } diff --git a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h index 708ee44293851b..fce3022909c7f5 100644 --- a/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h +++ b/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h @@ -44,6 +44,8 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + // (optionally) side_input) => output + kForwardGraph, // pointwise(...pointwise(conv(input, filter))...) + // => output }; enum class CudnnfMHAKind { @@ -130,6 +132,7 @@ extern const absl::string_view kCudnnConvForwardCallTarget; extern const absl::string_view kCudnnConvBackwardInputCallTarget; extern const absl::string_view kCudnnConvBackwardFilterCallTarget; extern const absl::string_view kCudnnConvBiasActivationForwardCallTarget; +extern const absl::string_view kCudnnConvForwardGraphCallTarget; // cuDNN specific convolution helper (emitted together with a int8x32 // convolution, if reordering is required). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 8bf17cca183d0f..8dde0aef08b584 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" namespace xla { namespace gpu { @@ -296,6 +297,239 @@ StatusOr FuseConvAlpha(HloComputation* comp) { return changed; } +bool IsF8Type(const HloInstruction* instr) { + return primitive_util::IsF8Type(instr->shape().element_type()); +} + +// Format of the serialized graph is +// "conv[output type]->op name[output type]->op name[output type]->...". 
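// A rough usage sketch (ours; the enumerator spellings are assumed) of the
// GraphString helper defined just below. Strings of exactly this shape appear
// in the "serialized_graph" attributes checked by the rewriter tests further
// down:
//
//   GraphString graph_string;
//   graph_string.AppendOp("conv", F32);       // "conv[f32]->"
//   graph_string.AppendOp("add", F32);        // "conv[f32]->add[f32]->"
//   graph_string.ChangeDataType(F8E4M3FN);    // "conv[f32]->add[f8e4m3fn]->"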
+class GraphString { + public: + GraphString() : size_(0) {} + + void AppendOp(std::string op_name, PrimitiveType type) { + graph_.append(op_name + "[" + + primitive_util::LowercasePrimitiveTypeName(type) + "]->"); + size_++; + } + + void ChangeDataType(PrimitiveType type) { + std::string::size_type m = graph_.find_last_of('['); + std::string::size_type n = graph_.find_last_of(']'); + graph_.replace(m + 1, n - m - 1, + primitive_util::LowercasePrimitiveTypeName(type)); + } + + int Size() { return size_; } + + std::string Graph() { return graph_; } + + private: + std::string graph_; + int size_; +}; + +// Recursively captures and serializes the graph of pointwise operations +// operating on the convolution. +bool CaptureConvGraphRecursive(HloInstruction* instr, + std::vector& operands, + GraphString& graph_string, + absl::flat_hash_set& visited_instrs, + HloInstruction*& final_instr, + int pattern_level = 0) { + // The maximum depth of the considered patterns. + const int max_pattern_level = 1; + // Avoid visiting the same instruction more than once. + if (!visited_instrs.emplace(instr->unique_id()).second) { + return false; + } + // When the function was called from outside or after a successful match, set + // the final instruction to the current instruction. + if (pattern_level == 0) { + final_instr = instr; + } + + HloInstruction *op, *operand; + for (HloInstruction* user : instr->users()) { + if (pattern_level == 0) { + // Add + if (Match(user, m::AddAnyOrder(&op, m::Op(), m::Op(&operand)))) { + graph_string.AppendOp("add", op->shape().element_type()); + operands.push_back(operand); + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + } + // Scale + if (Match(user, m::MultiplyAnyOrder(&op, m::Op(), + m::Broadcast(m::Op(&operand))))) { + graph_string.AppendOp("scale", op->shape().element_type()); + operands.push_back(operand); + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + } + // Inverse Scale + if (Match(user, m::Divide(&op, m::Op(), m::Broadcast(m::Op(&operand))))) { + graph_string.AppendOp("invscale", op->shape().element_type()); + operands.push_back(operand); + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + } + // ReLU + if (Match(user, m::MaximumAnyOrder(&op, m::Op(), + m::Broadcast(m::ConstantScalar(0))))) { + graph_string.AppendOp("relu", op->shape().element_type()); + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + } + } + + if (pattern_level == 1) { + // Convert with clamp + if (Match(user, + m::Convert(&op, + m::Clamp(m::Broadcast(m::ConstantScalar()), m::Op(), + m::Broadcast(m::ConstantScalar()))))) { + graph_string.ChangeDataType(op->shape().element_type()); + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + } + } + + // If none of the matches was successful and the pattern level is below the + // maximum level, attempt to match at higher level. + if (pattern_level < max_pattern_level) { + return CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, + pattern_level + 1); + } + } + + // The first entry in the serialized graph is the convolution. The size is + // greater than one if at least one match was succesful. + if (graph_string.Size() > 1) { + return true; + } else { + return false; + } +} + +// Captures the graph of pointwise operations operating on the convolution. 
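// A worked trace (ours), following the conv -> multiply(broadcast(z_scale)) ->
// clamp -> convert(f8e4m3fn) chain exercised by the rewriter tests below:
//
//   graph_string = "conv[f32]->"                  (seeded by CaptureConvGraph)
//   the multiply matches "scale" at pattern level 0:
//   graph_string = "conv[f32]->scale[f32]->"      operands += {z_scale}
//   the clamp matches nothing at level 0, so the walk retries at level 1,
//   where convert(clamp(...)) matches and ChangeDataType is applied:
//   graph_string = "conv[f32]->scale[f8e4m3fn]->" final_instr = the convert
//
// CaptureConvGraph below seeds the recursion with the convolution entry and
// appends the input/filter dequantization scales right after it.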
+bool CaptureConvGraph(HloInstruction* instr, + std::vector& operands, + GraphString& graph_string, HloInstruction*& final_instr, + HloInstruction* x_scale, HloInstruction* w_scale, + bool x_mult_scale, bool w_mult_scale) { + absl::flat_hash_set visited_instrs; + graph_string.AppendOp("conv", instr->shape().element_type()); + + // Shift the scaling of the inputs to the output of the convolution. + if (x_scale && w_scale && x_mult_scale == w_mult_scale) { + HloInstruction* product = + instr->AddInstruction(HloInstruction::CreateBinary( + x_scale->shape(), HloOpcode::kMultiply, x_scale, w_scale)); + operands.push_back(product); + graph_string.AppendOp(x_mult_scale ? "scale" : "invscale", + instr->shape().element_type()); + } else { + if (x_scale) { + operands.push_back(x_scale); + graph_string.AppendOp(x_mult_scale ? "scale" : "invscale", + instr->shape().element_type()); + } + if (w_scale) { + operands.push_back(w_scale); + graph_string.AppendOp(w_mult_scale ? "scale" : "invscale", + instr->shape().element_type()); + } + } + return CaptureConvGraphRecursive(instr, operands, graph_string, + visited_instrs, final_instr); +} + +// Matches convolutions operating on FP8 inputs and filters and rewrites into a +// ForwardGraph Custom Call. For scaled FP8 convolutions on Hopper systems, the +// following steps are elided and rewritten into a ForwardGraph Custom Call: +// +// 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32. +// 2. Unscale the filter and input by multiplying them by the corresponding +// input scale. +// 3. Evaluate the convolution based on the scaled filter and input. +// 4. Optionally add a matrix bias to the result and apply a ReLU activation. +// 5. Scale the output by dividing the output by the output scale. +// 6. Cast the output back to FP8. + +StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { + bool changed = false; +#if CUDA_VERSION >= 12000 + for (auto instr : comp->MakeInstructionPostOrder()) { + const DebugOptions& debug_options = + instr->GetModule()->config().debug_options(); + if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) { + return false; + } + HloInstruction *convolution, *gte, *input, *filter, *final_instr, + *x_scale = nullptr, *w_scale = nullptr, *x_scale_op = nullptr, + *w_scale_op = nullptr; + std::vector operands; + GraphString graph_string; + + // TODO(philipphack): Consider allowing ops between dequantization and + // convolution. + auto pattern = m::GetTupleElement( + >e, + m::CustomCall( + &convolution, + m::AnyOf( + m::Convert(m::Op(&input).WithPredicate(IsF8Type)), + m::Divide(&x_scale_op, + m::Convert(m::Op(&input).WithPredicate(IsF8Type)), + m::Broadcast(m::Op(&x_scale))), + m::MultiplyAnyOrder( + &x_scale_op, + m::Convert(m::Op(&input).WithPredicate(IsF8Type)), + m::Broadcast(m::Op(&x_scale)))), + m::AnyOf( + m::Convert(m::Op(&filter).WithPredicate(IsF8Type)), + m::Divide(&w_scale_op, + m::Convert(m::Op(&input).WithPredicate(IsF8Type)), + m::Broadcast(m::Op(&x_scale))), + m::MultiplyAnyOrder( + &w_scale_op, + m::Convert(m::Op(&filter).WithPredicate(IsF8Type)), + m::Broadcast(m::Op(&w_scale))))), + 0); + if (Match(instr, pattern) && + CaptureConvGraph( + const_cast(instr), operands, graph_string, + final_instr, x_scale, w_scale, + x_scale_op ? x_scale_op->opcode() == HloOpcode::kMultiply : false, + w_scale_op ? 
w_scale_op->opcode() == HloOpcode::kMultiply + : false)) { + TF_ASSIGN_OR_RETURN( + auto config, convolution->backend_config()); + config.set_serialized_graph(graph_string.Graph()); + operands.insert(operands.begin(), input); + operands.insert(operands.begin() + 1, filter); + + Shape new_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::ChangeElementType( + ShapeUtil::GetTupleElementShape(convolution->shape(), 0), + final_instr->shape().element_type()), + ShapeUtil::GetTupleElementShape(convolution->shape(), 1)}); + HloInstruction* new_convolution = comp->AddInstruction( + convolution->CloneWithNewOperands(new_shape, operands)); + new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget); + TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_convolution, 0)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); + changed = true; + } + } +#endif + return changed; +} + StatusOr FuseBiasOrSideInput(HloComputation* comp) { bool changed = false; for (auto instr : comp->MakeInstructionPostOrder()) { @@ -1018,8 +1252,14 @@ StatusOr CudnnFusedConvRewriter::Run( for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { - // Fuse "inside out" starting with the operations closest to the conv. bool changed = false; + // Rewrite FP8 convolutions and supported adjacent pointwise ops into a + // ForwardGraph Custom Call. + TF_ASSIGN_OR_RETURN(changed, F8GraphConv(comp, compute_capability_)); + if (changed) { + return changed; + } + // Fuse "inside out" starting with the operations closest to the conv. TF_ASSIGN_OR_RETURN(changed, FuseRemoveConvertInConv(comp)); any_changed |= changed; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 0ac90d1ca480aa..8ba4995c3b110a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -139,6 +139,27 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget))); } } + + void TestF8(absl::string_view pre_hlo_string, + absl::string_view post_hlo_string) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() << "FP8 convolutions require Hopper or newer architecture."; + } + std::string alpha_conv_scalar, alpha_side_input_scalar; + std::string elementwise_type; + + std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string); + EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert"))); + EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv")); + EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.1})) + << pre_hlo_string; + + StatusOr filecheck_result = + RunFileCheck(optimized_hlo_string, post_hlo_string); + ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); + EXPECT_TRUE(*filecheck_result); + } }; TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { @@ -601,6 +622,173 @@ TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01})); } +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledYF8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e4m3fn[3,3,128,16] parameter(1) + input_f32 = f32[1,128,6,6] convert(input) + filter_f32 = f32[3,3,128,16] 
convert(filter) + z_scale = f32[] parameter(2) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) + c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) + + })", + // post_hlo + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f8e4m3fn]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e4m3fn[3,3,128,16] parameter(1) + input_scale = f32[] parameter(2) + input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} + filter_scale = f32[] parameter(3) + filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} + input_f32 = f32[1,128,6,6] convert(input) + input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) + filter_f32 = f32[3,3,128,16] convert(filter) + filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) + z_scale = f32[] parameter(4) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) 
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) + + })", + // post_hlo + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e4m3fn]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e4m3fn[3,3,128,16] parameter(1) + input_scale = f32[] parameter(2) + input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} + filter_scale = f32[] parameter(3) + filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} + input_f32 = f32[1,128,6,6] convert(input) + input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) + filter_f32 = f32[3,3,128,16] convert(filter) + filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) + bias = f32[1,16,6,6] parameter(4) + z_scale = f32[] parameter(5) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_bias = f32[1,16,6,6] add(conv_a, bias) + conv_a_scaled = f32[1,16,6,6] multiply(conv_a_bias, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) + c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) + + })", + // post_hlo + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003eadd[f32]-\u003escale[f8e4m3fn]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledF8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e4m3fn[3,3,128,16] parameter(1) + input_f32 = f32[1,128,6,6] convert(input) + filter_f32 = f32[3,3,128,16] convert(filter) + z_scale = f32[] parameter(2) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] divide(conv_a, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) 
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) + + })", + // post_hlo + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: "serialized_graph":"conv[f32]-\u003einvscale[f8e4m3fn]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluActivationF8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e4m3fn[3,3,128,16] parameter(1) + input_f32 = f32[1,128,6,6] convert(input) + filter_f32 = f32[3,3,128,16] convert(filter) + z_scale = f32[] parameter(2) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + c = f32[] constant(0) + c_bcast = f32[1,16,6,6] broadcast(c), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + relu_a = f32[1,16,6,6] maximum(conv_a, c_bcast) + relu_a_scaled = f32[1,16,6,6] multiply(relu_a, z_scale_bcast) + c1 = f32[] constant(-448.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(448.) + c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + relu_a_clamped = f32[1,16,6,6] clamp(c1_bcast, relu_a_scaled, c2_bcast) + ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped) + + })", + // post_hlo + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: "serialized_graph":"conv[f32]-\u003erelu[f32]-\u003escale[f8e4m3fn]-\u003e" + )"); +} + TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) { // max(0, clamp(conv(x, w)))); for int8_t TestClamp( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc index ac7efb51e31dae..56adc340b5fc7c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.cc @@ -211,6 +211,7 @@ static StatusOr TryResolvePaddedShapesForTensorCore( switch (kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: + case CudnnConvKind::kForwardGraph: return std::make_tuple(&new_lhs_shape, &new_rhs_shape, &new_result_shape); case CudnnConvKind::kBackwardInput: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc index 1eaa56dd7db28b..5cd7979f927955 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.cc @@ -33,7 +33,9 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || - conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); + conv.custom_call_target() == + kCudnnConvBiasActivationForwardCallTarget || + conv.custom_call_target() == kCudnnConvForwardGraphCallTarget); return 
window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -415,6 +417,7 @@ StatusOr GpuConvPaddingLegalization::RunOnComputation( switch (kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: + case CudnnConvKind::kForwardGraph: return CanonicalizeForwardConvolution(instruction); case CudnnConvKind::kBackwardInput: return CanonicalizeBackwardInputConvolution(instruction); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 3e9ba809238283..941817365782c2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -87,6 +87,57 @@ Status RunGpuConvUnfused(const GpuConvParams& params, se::Stream* stream, filter_buf, output_buf); } +template +Status RunGpuConvGraph(const GpuConvParams& params, se::Stream* stream, + RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_memory) { + if (params.config->conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.config->conv_result_scale); + } + + TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, + GetDNNConvKindFromCudnnConvKind(params.config->kind)); + + TF_ASSIGN_OR_RETURN( + se::dnn::DataType input_type, + GetDNNDataTypeFromPrimitiveType(params.config->input_type)); + + TF_ASSIGN_OR_RETURN( + se::dnn::DataType output_type, + GetDNNDataTypeFromPrimitiveType(params.config->output_type)); + + se::dnn::LazyOpRunner* lazy_runner = + options.runner_cache->AsGraphConvRunner(); + std::optional> local_runner; + if (!lazy_runner) { + local_runner.emplace(params.config->algorithm); + lazy_runner = &*local_runner; + } + + se::dnn::GraphConvOp::Config config{kind, + input_type, + output_type, + params.config->input_descriptor, + params.config->filter_descriptor, + params.config->output_descriptor, + params.config->conv_desc, + params.config->serialized_graph}; + TF_ASSIGN_OR_RETURN(auto* runner, + lazy_runner->GetOrCreateRunner(config, stream)); + + std::vector operands = {input_buf, filter_buf, output_buf}; + // Insert the optional operands ahead of the output. 
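  // As a concrete illustration (hypothetical operand names, not from this
  // change): if the captured graph has two pointwise operands {scale0, scale1},
  // the vector below becomes {input_buf, filter_buf, scale0, scale1, output_buf},
  // i.e. the extra graph operands sit between the filter and the output.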
+ operands.insert(operands.end() - 1, params.operand_bufs.begin(), + params.operand_bufs.end()); + + return (*runner)(stream, options.profile_result, scratch_memory, operands); +} + template Status RunGpuConvForwardActivation(const GpuConvParams& params, se::Stream* stream, RunConvOptions options, @@ -174,6 +225,9 @@ Status RunGpuConvInternalImpl(const GpuConvParams& params, se::Stream* stream, return RunGpuConvForwardActivation( params, stream, options, input_buf, filter_buf, output_buf, scratch_memory); + case CudnnConvKind::kForwardGraph: + return RunGpuConvGraph(params, stream, options, input_buf, filter_buf, + output_buf, scratch_memory); } } return OkStatus(); @@ -272,10 +326,12 @@ StatusOr GetGpuConvConfig( config.kind = desc.kind; config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); config.conv_result_scale = backend_config.conv_result_scale(); + config.serialized_graph = backend_config.serialized_graph(); switch (config.kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: + case CudnnConvKind::kForwardGraph: config.input_shape = operand0_shape; config.filter_shape = operand1_shape; config.output_shape = result_shape; @@ -494,6 +550,7 @@ StatusOr GetGpuConvParams( switch (config.kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: + case CudnnConvKind::kForwardGraph: params.input_buf = operand_buffers[0]; params.filter_buf = operand_buffers[1]; params.output_buf = result_buffer; @@ -510,6 +567,10 @@ StatusOr GetGpuConvParams( break; } + if (config.kind == CudnnConvKind::kForwardGraph) { + params.operand_bufs = {operand_buffers.begin() + 2, operand_buffers.end()}; + } + if (config.kind == CudnnConvKind::kForwardActivation) { params.fusion.emplace(); GpuConvParams::FusionParams& fusion = *params.fusion; @@ -532,6 +593,20 @@ Status RunGpuConv(const gpu::GpuConvConfig& config, PrimitiveType input_primitive_type = config.input_type; switch (input_primitive_type) { + case F8E4M3FN: + if (config.kind != CudnnConvKind::kForwardGraph) { + return InternalError("FP8 convolution requires graph mode."); + } + return RunGpuConvImpl(params, stream, scratch_memory, + options); + case F8E5M2: + if (config.kind != CudnnConvKind::kForwardGraph) { + return InternalError("FP8 convolution requires graph mode."); + } + return RunGpuConvImpl(params, stream, scratch_memory, + options); case F16: return RunGpuConvImpl( params, stream, scratch_memory, options); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 0c27d10099121e..dd490565ea38aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_ #include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" @@ -69,6 +70,7 @@ struct GpuConvConfig { Shape filter_shape; Shape output_shape; std::optional fusion; + std::string serialized_graph; }; // Implementation struct exposed for debugging and log analysis. @@ -83,6 +85,9 @@ struct GpuConvParams { se::DeviceMemoryBase filter_buf; se::DeviceMemoryBase output_buf; + // Buffers for operands of pointwise ops. 
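  // In graph mode these hold the custom-call operands that follow the input
  // and filter (operand_buffers[2..]), in the order they appear on the custom
  // call; see GetGpuConvParams in gpu_conv_runner.cc above.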
+ std::vector operand_bufs; + std::optional fusion; }; @@ -91,28 +96,40 @@ struct GpuConvParams { // naturally play well with the typed APIs provided by StreamExecutor; rather // than rewriting everything here, just propagate the dynamic typing to one more // place by having either a FusedConvRunner or a ConvRunner. -class MaybeFusedConvRunner { +class GenericConvRunner { public: - MaybeFusedConvRunner() = default; + GenericConvRunner() = default; - explicit MaybeFusedConvRunner( + explicit GenericConvRunner( std::unique_ptr> runner) : repr_(std::move(runner)) {} - explicit MaybeFusedConvRunner( + explicit GenericConvRunner( + std::unique_ptr> runner) + : repr_(std::move(runner)) {} + + explicit GenericConvRunner( std::unique_ptr> runner) : repr_(std::move(runner)) {} - explicit MaybeFusedConvRunner(const GpuConvConfig& config) - : MaybeFusedConvRunner( - config.kind == CudnnConvKind::kForwardActivation - ? MaybeFusedConvRunner( - std::make_unique< - se::dnn::LazyOpRunner>( - config.algorithm)) - : MaybeFusedConvRunner( - std::make_unique>( - config.algorithm))) {} + explicit GenericConvRunner(const GpuConvConfig& config) + : GenericConvRunner(FromGpuConvConfig(config)) {} + + static GenericConvRunner FromGpuConvConfig(const GpuConvConfig& config) { + if (config.kind == CudnnConvKind::kForwardGraph) { + return GenericConvRunner( + std::make_unique>( + config.algorithm)); + } else if (config.kind == CudnnConvKind::kForwardActivation) { + return GenericConvRunner( + std::make_unique>( + config.algorithm)); + } else { + return GenericConvRunner( + std::make_unique>( + config.algorithm)); + } + } se::dnn::AlgorithmDesc ToAlgorithmDesc() const { return std::visit(ToAlgorithmDescVisitor{}, repr_); @@ -126,6 +143,15 @@ class MaybeFusedConvRunner { .get(); } + se::dnn::LazyOpRunner* AsGraphConvRunner() { + CHECK(std::holds_alternative< + std::unique_ptr>>(repr_)); + return std::get< + std::unique_ptr>>( + repr_) + .get(); + } + se::dnn::LazyOpRunner* AsFusedConvRunner() { CHECK(std::holds_alternative< std::unique_ptr>>(repr_)); @@ -150,6 +176,7 @@ class MaybeFusedConvRunner { using Repr = std::variant>, + std::unique_ptr>, std::unique_ptr>>; Repr repr_; }; @@ -160,7 +187,7 @@ struct RunConvOptions { // Use this runner cache (and its configured algorithm), instead of the one // from the instruction. - MaybeFusedConvRunner* runner_cache; + GenericConvRunner* runner_cache; }; // This file contains low-level routines for running cudnn convolutions. 
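GenericConvRunner above keeps one of three LazyOpRunner flavors in a std::variant and dispatches on whichever alternative is currently held. A minimal, self-contained sketch of that variant-plus-visitor pattern follows; all names in it are illustrative and not part of the codebase.

#include <iostream>
#include <memory>
#include <variant>

struct PlainRunner { int algorithm() const { return 1; } };
struct GraphRunner { int algorithm() const { return 2; } };

using RunnerRepr = std::variant<std::unique_ptr<PlainRunner>,
                                std::unique_ptr<GraphRunner>>;

int AlgorithmOf(const RunnerRepr& repr) {
  // std::visit invokes the lambda with whichever alternative the variant holds.
  return std::visit([](const auto& runner) { return runner->algorithm(); },
                    repr);
}

int main() {
  RunnerRepr repr = std::make_unique<GraphRunner>();
  std::cout << AlgorithmOf(repr) << "\n";  // Prints 2.
}

Adding kForwardGraph as a new alternative then only requires one more constructor overload and one more accessor, which is what the rename to GenericConvRunner above does.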
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index c0f9a392717f3c..c8140d0cb1534b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -87,6 +87,11 @@ HeuristicLayoutAssignment(const HloInstruction* instr, return kAllNHWC; } + if (primitive_util::IsF8Type(input_ty)) { + VLOG(2) << "Using NHWC for FP8 conv " << instr->ToString(); + return kAllNHWC; + } + const DebugOptions& debug_options = instr->GetModule()->config().debug_options(); @@ -158,6 +163,7 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( switch (kind) { case CudnnConvKind::kForward: case CudnnConvKind::kForwardActivation: + case CudnnConvKind::kForwardGraph: input_shape = &lhs_shape; filter_shape = &rhs_shape; output_shape = &result_shape; @@ -200,20 +206,33 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( TF_RETURN_IF_ERROR(SetOperandLayout(lhs_shape, instr, 0)); TF_RETURN_IF_ERROR(SetOperandLayout(rhs_shape, instr, 1)); TF_RETURN_IF_ERROR(SetBufferLayout(result_shape.layout(), *call_result_buf)); - // instr->operand(2), if exists, is the bias buffer. There is no need to - // assign layout to it, as it has only one dimension. - + // For fused convolutions, instr->operand(2), if exists, is the bias buffer. + // There is no need to assign layout to it, as it has only one dimension. // instr->operand(3), if exists, is the side input buffer. - if (instr->operand_count() == 4) { - if (kind != CudnnConvKind::kForwardActivation) { - return InternalError( - "Invalid convolution. Conv has a side input, but kind is not fused " - "conv forward: %s", - instr->ToString()); - } + if (kind == CudnnConvKind::kForwardActivation && + instr->operand_count() == 4) { // The side input layout must match the output layout. TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, 3)); } + + // For graph convolutions, align the layouts of the non-scalar inputs to any + // pointwise ops with the output layout. + if (kind == CudnnConvKind::kForwardGraph) { + for (int k = 2; k < instr->operand_count(); ++k) { + if (!ShapeUtil::IsScalar(instr->operand(k)->shape())) { + TF_RETURN_IF_ERROR(SetOperandLayout(*output_shape, instr, k)); + } + } + } + + if (instr->operand_count() > 2 && kind != CudnnConvKind::kForwardActivation && + kind != CudnnConvKind::kForwardGraph) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward or graph conv foward: %s", + instr->ToString()); + } + return OkStatus(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d1f60a0e977385..c798a9bc6cb574 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -844,6 +844,7 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { using mlir::lmhlo_gpu::ConvBackwardInputOp; using mlir::lmhlo_gpu::ConvForwardFusedOp; using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; + using mlir::lmhlo_gpu::ConvForwardGraphOp; using mlir::lmhlo_gpu::ConvForwardOp; // Last 2 operands of the convolution operation are the result and scratch. 
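// Data-flow note (summarizing the pieces above and below): the serialized
// graph string starts in the HLO custom call's CudnnConvBackendConfig, is
// carried as an attribute on the LMHLO ConvForwardGraphOp (read back via
// getSerializedGraph() in the hunk below), is copied into
// GpuConvConfig::serialized_graph by GetGpuConvConfig, and is finally parsed
// by GetGenericCudnnOperationGraph in cuda_dnn.cc when the cuDNN operation
// graph is built.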
@@ -942,6 +943,11 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { } else if (auto conv = dyn_cast(op)) { descriptor.kind = CudnnConvKind::kBackwardFilter; fill_conv_descriptor(conv); + } else if (auto conv = dyn_cast(op)) { + descriptor.kind = CudnnConvKind::kForwardGraph; + fill_conv_descriptor(conv); + descriptor.backend_config.set_serialized_graph( + conv.getSerializedGraph().data()); } else if (auto conv = dyn_cast(op)) { descriptor.kind = CudnnConvKind::kForwardActivation; fill_conv_descriptor(conv); @@ -955,7 +961,7 @@ Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) { descriptor.backend_config.set_side_input_scale( conv.getSideInputScale().convertToDouble()); } else { - return InternalError("Unexpected operation"); + return InternalError("EmitConvolutionThunk: Unexpected operation"); } TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, "")); AddThunkToThunkSequence(std::make_unique( @@ -4681,6 +4687,7 @@ Status IrEmitterUnnested::EmitOp(mlir::Operation* op) { #endif // GOOGLE_CUDA if (mlir::isaconfig) {} GpuConvConfig config; - MaybeFusedConvRunner runner; + GenericConvRunner runner; }; class StreamExecutorConvRunners : public runtime::StateVector {}; diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 6765af0a38330f..9a19f4ac9a96c7 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -485,6 +485,8 @@ StatusOr GetDNNConvKindFromCudnnConvKind( return se::dnn::FORWARD; case CudnnConvKind::kForwardActivation: return se::dnn::FORWARD_BIAS_ACTIVATION; + case CudnnConvKind::kForwardGraph: + return se::dnn::FORWARD_GRAPH; default: break; } @@ -535,6 +537,10 @@ StatusOr GetDNNDataTypeFromPrimitiveType( return se::dnn::ToDataType::value; case BF16: return se::dnn::ToDataType::value; + case F8E4M3FN: + return se::dnn::ToDataType::value; + case F8E5M2: + return se::dnn::ToDataType::value; default: break; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 3c44c74fe7cd07..85de298679ab53 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -153,8 +153,8 @@ OperandLayoutConstraint::OperandLayoutConstraint( instruction_(instruction), operand_no_(operand_no) { CHECK(shape_layout.LayoutIsSet()); - CHECK(ShapeUtil::Compatible(shape_layout.shape(), - instruction->operand(operand_no)->shape())) + CHECK(ShapeUtil::CompatibleIgnoringElementType( + shape_layout.shape(), instruction->operand(operand_no)->shape())) << shape_layout.shape() << " is not compatible with " << instruction->operand(operand_no)->shape() << " (for operand " << operand_no << " of instruction " << instruction->ToString() << ")"; diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 5b22a54a9957a9..3f426604d2aba0 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -374,7 +374,8 @@ void PreloadCudnnSubLibs(PreloadCudnnType type) { void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) { switch (kind) { - case dnn::ConvolutionKind::FORWARD: { + case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_GRAPH: { PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd); break; } @@ 
-1135,6 +1136,12 @@ cudnnDataType_t ToCudnnDataType( #if CUDNN_VERSION >= 8200 case dnn::DataType::kBF16: return CUDNN_DATA_BFLOAT16; +#endif +#if CUDNN_VERSION >= 8900 + case dnn::DataType::kF8E4M3FN: + return CUDNN_DATA_FP8_E4M3; + case dnn::DataType::kF8E5M2: + return CUDNN_DATA_FP8_E5M2; #endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); @@ -3431,6 +3438,11 @@ dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) { return CudnnEnvVar::IsEnabled() ? dnn::DataType::kFloat : dnn::DataType::kBF16; +#endif +#if CUDNN_VERSION >= 8900 + case dnn::DataType::kF8E4M3FN: + case dnn::DataType::kF8E5M2: + return dnn::DataType::kFloat; #endif default: LOG(FATAL) << "Invalid DNN data type: " << static_cast(data_type); @@ -3449,7 +3461,8 @@ cudnnBackendDescriptorType_t GetCudnnConvolutionType( dnn::ConvolutionKind kind) { cudnnBackendDescriptorType_t conv_mode; switch (kind) { - case dnn::ConvolutionKind::FORWARD: { + case dnn::ConvolutionKind::FORWARD: + case dnn::ConvolutionKind::FORWARD_GRAPH: { conv_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR; break; } @@ -4140,6 +4153,287 @@ GetCudnnOperationGraph(dnn::ConvolutionKind kind, dnn::DataType input_type, return std::make_unique(std::move(opGraph)); } +enum class InputKind { kNone, kScalar, kTensor }; + +tsl::StatusOr PrimitiveTypeStringToDnnType( + string data_type_string) { + if (data_type_string == "f8e4m3fn") { + return dnn::DataType::kF8E4M3FN; + } else if (data_type_string == "f8e5m2") { + return dnn::DataType::kF8E5M2; + } else if (data_type_string == "bf16") { + return dnn::DataType::kBF16; + } else if (data_type_string == "f16") { + return dnn::DataType::kHalf; + } else if (data_type_string == "f32") { + return dnn::DataType::kFloat; + } else { + return tsl::errors::Internal("Unsupported primitive type."); + } +} + +tsl::StatusOr> +OpNameStringToInputKindAndMode(string opstring) { +#define KIND_AND_MODE_FROM_OP_STRING(OPSTRING, INPUTKIND, PWMODE) \ + if (opstring == OPSTRING) { \ + return std::make_pair(INPUTKIND, PWMODE); \ + } + + KIND_AND_MODE_FROM_OP_STRING("add", InputKind::kTensor, CUDNN_POINTWISE_ADD) + KIND_AND_MODE_FROM_OP_STRING("relu", InputKind::kNone, + CUDNN_POINTWISE_RELU_FWD) + KIND_AND_MODE_FROM_OP_STRING("scale", InputKind::kScalar, CUDNN_POINTWISE_MUL) + KIND_AND_MODE_FROM_OP_STRING("invscale", InputKind::kScalar, + CUDNN_POINTWISE_DIV) + +#undef KIND_AND_MODE_FROM_OP_STRING + + return tsl::errors::Internal("Unknown op."); +} + +tsl::StatusOr, + std::vector>> +GetGenericCudnnOperationGraph( + dnn::ConvolutionKind kind, dnn::DataType input_type, + const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + CudnnHandle& cudnn, string serialized_graph = "") { + PreloadCudnnSubLibsHelper(kind); + std::vector virtual_uids, non_virtual_uids; + + // Struct to describe the ops (convolution and pointwise) in the sequence + // described by the graph. + struct SequentialOpDescriptor { + InputKind input_kind; + std::variant mode; + dnn::DataType output_type; + }; + + // Format of the serialized graph is + // "conv[output type]->op name[output type]->op name[output type]->...". 
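  // For example, a (hypothetical) graph
  // "conv[f32]->scale[f32]->relu[f32]->scale[f8e4m3fn]->" would be
  // deserialized below into four SequentialOpDescriptors:
  //   {kNone,   CUDNN_CROSS_CORRELATION (or CUDNN_CONVOLUTION), kFloat}     // conv
  //   {kScalar, CUDNN_POINTWISE_MUL,                            kFloat}     // scale
  //   {kNone,   CUDNN_POINTWISE_RELU_FWD,                       kFloat}     // relu
  //   {kScalar, CUDNN_POINTWISE_MUL,                            kF8E4M3FN}  // scale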
+ auto deserialize_cudnn_graph = + [&]() -> tsl::StatusOr> { + std::vector op_sequence = {}; + string::size_type pos = 0; + while (pos < serialized_graph.size()) { + std::variant mode; + dnn::DataType output_type; + InputKind input_kind = InputKind::kNone; + string::size_type m = serialized_graph.find('[', pos); + string::size_type n = serialized_graph.find(']', pos); + string op_string = serialized_graph.substr(pos, m - pos); + string data_type_string = serialized_graph.substr(m + 1, n - m - 1); + TF_ASSIGN_OR_RETURN(output_type, + PrimitiveTypeStringToDnnType(data_type_string)); + if (op_string == "conv") { + mode = convolution_descriptor.convolution_not_crosscorr() + ? CUDNN_CONVOLUTION + : CUDNN_CROSS_CORRELATION; + } else { + TF_ASSIGN_OR_RETURN(std::tie(input_kind, mode), + OpNameStringToInputKindAndMode(op_string)); + } + op_sequence.push_back({input_kind, mode, output_type}); + pos = n + 3; + } + return op_sequence; + }; + + TF_ASSIGN_OR_RETURN(std::vector op_sequence, + deserialize_cudnn_graph()); + + if (op_sequence.empty()) { + return tsl::errors::Internal("No supported ops in convolution graph."); + } + + cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind); + + std::vector ops = {}; + + // x tensor. + int vector_size, vector_dim; + std::tie(vector_size, vector_dim) = + GetTensorVectorSizeAndDim(input_descriptor, input_type); + std::vector input_dims = input_descriptor.vectorized_dims( + dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); + std::vector input_strides = input_descriptor.vectorized_strides( + dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); + + TF_ASSIGN_OR_RETURN(auto tensor_x, + CreateCudnnTensor(input_dims, input_strides, 'x', + input_type, vector_size, vector_dim)); + non_virtual_uids.push_back('x'); + + // w tensor. + std::tie(vector_size, vector_dim) = + GetTensorVectorSizeAndDim(filter_descriptor, input_type); + std::vector filter_dims = filter_descriptor.vectorized_dims( + dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); + std::vector filter_strides = filter_descriptor.vectorized_strides( + dnn::FilterLayout::kOutputInputYX, vector_size, vector_dim); + + cudnnBackendTensorReordering_t tensor_ordering_type = + filter_descriptor.layout() == + dnn::FilterLayout::kOutputInputYX32_CudnnReordered + ? CUDNN_TENSOR_REORDERING_INT8x32 + : CUDNN_TENSOR_REORDERING_NONE; + TF_ASSIGN_OR_RETURN( + auto tensor_w, + CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, + vector_size, vector_dim, + /*is_virtual=*/false, tensor_ordering_type)); + non_virtual_uids.push_back('w'); + + // y tensor. 
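  // The convolution's own result 'y' is created as a virtual (intermediate)
  // tensor typed with op_sequence[0].output_type; the pointwise results
  // further down are cloned from it, and only the result of the last op in
  // the sequence is non-virtual, i.e. actually materialized in memory.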
+ std::tie(vector_size, vector_dim) = + GetTensorVectorSizeAndDim(output_descriptor, op_sequence[0].output_type); + std::vector output_dims = output_descriptor.vectorized_dims( + dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); + std::vector output_strides = output_descriptor.vectorized_strides( + dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); + + TF_ASSIGN_OR_RETURN( + auto tensor_y, + CreateCudnnTensor(output_dims, output_strides, 'y', + op_sequence[0].output_type, vector_size, vector_dim, + /*is_virtual=*/true)); + virtual_uids.push_back('y'); + + auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(input_type)); + CHECK_NE(convolution_descriptor.pad_alignment(), + dnn::PadAlignment::kTensorFlowPadding) + << "TensorFlow padding alignment is not supported."; + + int conv_dim = convolution_descriptor.ndims(); + auto conv_desc = + cudnn_frontend::ConvDescBuilder() + .setComputeType(accumulator_type) + .setMathMode(std::get(op_sequence[0].mode)) + .setSpatialDimCount(conv_dim) + .setSpatialStride(conv_dim, convolution_descriptor.strides().data()) + .setPrePadding(conv_dim, convolution_descriptor.padding().data()) + .setPostPadding(conv_dim, convolution_descriptor.padding().data()) + .setDilation(conv_dim, convolution_descriptor.dilations().data()) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(conv_desc); + + // CUDNN Operation + double alpha = 1.0; + double beta = 0.0; + cudnn_frontend::Operation op = cudnn_frontend::OperationBuilder(conv_mode) + .setxDesc(tensor_x) + .setyDesc(tensor_y) + .setwDesc(tensor_w) + .setcDesc(conv_desc) + .setAlpha(alpha) + .setBeta(beta) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(op); + ops.push_back(std::move(op)); + VLOG(4) << "\nTensor_x: " << tensor_x.describe() + << "\nTensor_y: " << tensor_y.describe() + << "\nTensor_w: " << tensor_w.describe() + << "\nConv desc: " << conv_desc.describe() + << "\nOp: " << ops.back().describe(); + + // Add any pointwise ops to the cuDNN graph. + for (int op_num = 0; op_num < op_sequence.size(); ++op_num) { + SequentialOpDescriptor op_descriptor = op_sequence[op_num]; + if (std::holds_alternative(op_descriptor.mode)) { + std::optional second_operand; + // Create cuDNN tensors for operands of binary ops (side inputs). + if (op_descriptor.input_kind == InputKind::kScalar) { + std::vector scale_dim(4, 1); + second_operand = + cudnn_frontend::TensorBuilder() + .setDim(4, scale_dim.data()) + .setStrides(4, scale_dim.data()) + .setId(non_virtual_uids.emplace_back( + std::min(non_virtual_uids.back(), virtual_uids.back()) - 1)) + .setAlignment(32) + .setDataType( + ToCudnnDataType(op_sequence[op_num - 1].output_type)) + .build(); + VLOG(4) << "\nPointwise operand: " << second_operand->describe(); + } else if (op_descriptor.input_kind == InputKind::kTensor) { + second_operand = + cudnn_frontend::TensorBuilder() + .cloneFrom(tensor_y, non_virtual_uids.emplace_back( + std::min(non_virtual_uids.back(), + virtual_uids.back()) - + 1)) + .setVirtual(false) + .setAlignment(32) + .setDataType( + ToCudnnDataType(op_sequence[op_num - 1].output_type)) + .build(); + VLOG(4) << "\nPointwise operand: " << second_operand->describe(); + } + + // Create the result tensor of the op. 
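  // UID bookkeeping: each new tensor takes a UID one below the smaller of the
  // most recently assigned non-virtual and virtual UIDs, so IDs descend from
  // just under 'w'/'y'. The result of the final op is appended to
  // non_virtual_uids because it is the tensor the kernel writes out; all
  // intermediate results stay virtual.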
+ cudnn_frontend::Tensor result = + cudnn_frontend::TensorBuilder() + .cloneFrom( + tensor_y, + std::min(non_virtual_uids.back(), virtual_uids.back()) - 1) + .setVirtual(op_num != op_sequence.size() - 1) + .setDataType(ToCudnnDataType(op_descriptor.output_type)) + .build(); + VLOG(4) << "\nPointwise result: " << result.describe(); + + if (op_num == op_sequence.size() - 1) { + non_virtual_uids.emplace_back( + std::min(non_virtual_uids.back(), virtual_uids.back()) - 1); + } else { + virtual_uids.emplace_back( + std::min(non_virtual_uids.back(), virtual_uids.back()) - 1); + } + + // Create the descriptor of the op. + cudnn_frontend::PointWiseDesc desc = + cudnn_frontend::PointWiseDescBuilder() + .setMode(std::get(op_descriptor.mode)) + .setMathPrecision(CUDNN_DATA_FLOAT) + .build(); + VLOG(4) << "\nPointwise op desc: " << desc.describe(); + + // Add the op to the operation graph. + if (second_operand.has_value()) { + ops.emplace_back(cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(ops.back().getOutputTensor()) + .setbDesc(second_operand.value()) + .setyDesc(result) + .setpwDesc(desc) + .build()); + } else { + ops.emplace_back(cudnn_frontend::OperationBuilder( + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) + .setxDesc(ops.back().getOutputTensor()) + .setyDesc(result) + .setpwDesc(desc) + .build()); + } + RETURN_MSG_IF_CUDNN_ERROR(ops.back()); + VLOG(4) << "\nOp: " << ops.back().describe(); + } + } + + // Construct the cuDNN OperationGraph. + auto opGraph = cudnn_frontend::OperationGraphBuilder() + .setHandle(cudnn.handle()) + .setOperationGraph(ops) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(opGraph); + VLOG(4) << "\ncuDNN OperationGraph: " << opGraph.describe(); + + return make_pair(std::unique_ptr( + new cudnn_frontend::OperationGraph(std::move(opGraph))), + non_virtual_uids); +} + bool SideInputNeeded(dnn::ActivationMode activation_mode, double conv_scale, double side_input_scale) { // Cudnn uses precompiled kernels to perform the Conv-Add-BiasAdd-Act when the @@ -6215,6 +6509,75 @@ class CudnnExecutionPlanRunner return ExecutionPlanToAlgorithmDesc(plan_, workspace_size_); } + tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, + DeviceMemoryBase scratch_memory, + std::vector inputs) const { + if (static_cast(parent_) != + stream->parent()->implementation()) { + return tsl::errors::Internal( + "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); + } + + auto cudnn = cudnn_->GetHandle(parent_, stream); + + size_t workspace_size = plan_.getWorkspaceSize(); + RETURN_MSG_IF_CUDNN_ERROR(plan_); + + std::vector data_uids_vec = {data_uids_.cbegin(), + data_uids_.cend()}; + std::vector data_ptrs_vec; + for (DeviceMemoryBase input : inputs) { + data_ptrs_vec.push_back(input.opaque()); + } + + auto variantPack = + cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(scratch_memory.opaque()) + .setDataPointers(data_ptrs_vec.size(), data_ptrs_vec.data()) + .setUids(data_uids_vec.size(), data_uids_vec.data()) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(variantPack); + + VLOG(4) << "\nDo cudnn execution plan with plan tag: " << plan_.getTag() + << "\nWorkspace size in bytes: " << workspace_size + << "\nVariantPack: " << variantPack.describe(); + + const bool is_profiling = profile_result != nullptr; + + std::unique_ptr timer; + if (is_profiling) { + timer.reset(new GpuTimer(parent_)); // NOLINT + // The start and stop of the timer should be as close to the Cudnn call as + // possible. 
It is still possible for other threads to issue workload on + // to this stream. So it could take multiple profiling measurements. + if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { + return tsl::Status(absl::StatusCode::kInternal, + "Failed to start timer"); + } + } + + cudnnStatus_t status = cudnnBackendExecute( + cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc()); + RETURN_IF_CUDNN_ERROR(status); + + if (is_profiling) { + if (!timer->Stop(AsGpuStream(stream))) { + return tsl::Status(absl::StatusCode::kInternal, "Failed to stop timer"); + } + TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); + profile_result->set_algorithm(desc); + profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); + profile_result->set_scratch_size(scratch_memory.size()); + + VLOG(4) << "cudnn op with plan " << plan_.getTag() + << ", workspace_size=" << workspace_size << " -> " + << CudnnStatusToString(status) << " in " + << timer->GetElapsedMilliseconds() << "ms"; + } + + return tsl::OkStatus(); + } + tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, DeviceMemoryBase scratch_memory, Args... inputs) const override { @@ -6576,6 +6939,38 @@ tsl::Status CudnnSupport::GetConvolveRunners( #endif // CUDNN_VERSION >= 8100 && TF_ENABLE_CUDNN_FRONTEND } +tsl::Status CudnnSupport::GetGraphConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, + DeviceMemoryBase /*input_data*/, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase /*filter_data*/, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase /*output_data*/, + const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ScratchAllocator* /*scratch_allocator*/, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, + string serialized_graph) { + + if (!use_cudnn_frontend) { + return tsl::errors::Internal( + "cuDNN graph execution requires the use of the cuDNN frontend."); + } + + auto cudnn = cudnn_->GetHandle(parent_, stream); + TF_ASSIGN_OR_RETURN( + auto op_graph_and_uids, + GetGenericCudnnOperationGraph( + kind, input_type, input_descriptor, filter_descriptor, + output_descriptor, convolution_descriptor, cudnn, serialized_graph)); + return CreateOpRunners( + stream, cudnn, parent_, cudnn_.get(), std::move(op_graph_and_uids.first), + kind, input_type, op_graph_and_uids.second, use_fallback, out_exec_plans, + /*need_side_input=*/false, numeric_options); +} + tsl::StatusOr> CudnnSupport::ConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, @@ -6643,6 +7038,46 @@ CudnnSupport::ConvolveRunnerFromDesc( #endif } +tsl::StatusOr> +CudnnSupport::GraphConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + string serialized_graph) { + if (!algorithm_desc.is_cudnn_frontend()) { + return tsl::errors::Internal( + "cuDNN graph execution requires the use of the cuDNN frontend."); + } + +#if CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND + auto cudnn = cudnn_->GetHandle(parent_, stream); + + TF_ASSIGN_OR_RETURN( + auto op_graph_and_uids, + GetGenericCudnnOperationGraph( + 
kind, input_type, input_descriptor, filter_descriptor, + output_descriptor, convolution_descriptor, cudnn, serialized_graph)); + + TF_ASSIGN_OR_RETURN( + auto execution_plan, + RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph_and_uids.first)); + + TF_ASSIGN_OR_RETURN(auto runner, + CudnnExecutionPlanRunner::Create( + parent_, cudnn_.get(), std::move(execution_plan), + op_graph_and_uids.second, + /*need_side_input=*/false)); + return {std::make_unique>( + std::move(runner))}; +#else + return tsl::errors::Unimplemented( + "cuDNN graph execution requires cuDNN version 8.9 or higher."); +#endif +} + class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { public: // Queries the workspace size and constructs a 'CudnnLegacyFusedConvRunner'. diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h index e8cbe9207c4476..2639bdfe918600 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h @@ -238,6 +238,30 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor) override; + tsl::Status GetGraphConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + bool use_fallback, ScratchAllocator* scratch_allocator, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, + string serialized_graph) override; + + tsl::StatusOr> + GraphConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + string serialized_graph) override; + tsl::Status GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc index 7d48b355259f85..474fc51c017b8a 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/dnn.cc @@ -149,6 +149,36 @@ DnnSupport::ConvolveRunnerFromDesc( return tsl::errors::Unimplemented("ConvolveRunnerFromDesc not implemented."); } +tsl::Status DnnSupport::GetGraphConvolveRunners( + bool /* use_cudnn_frontend */, dnn::ConvolutionKind /*kind*/, + dnn::DataType /*input_type*/, dnn::DataType /*output_type*/, + Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/, + DeviceMemoryBase /*input_data*/, + const dnn::FilterDescriptor& /*filter_descriptor*/, + DeviceMemoryBase /*filter_data*/, + const dnn::BatchDescriptor& /*output_descriptor*/, + DeviceMemoryBase /*output_data*/, + const dnn::ConvolutionDescriptor& /*convolution_descriptor*/, + bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/, + const NumericOptions& /*numeric_options*/, + std::vector>* /*exec_plans*/, + string cudnn_grapph) { + return 
tsl::errors::Unimplemented("GetGraphConvolveRunners not implemented."); +} + +tsl::StatusOr> +DnnSupport::GraphConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType element_type, + dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + string serialized_graph) { + return tsl::errors::Unimplemented( + "GraphConvolveRunnerFromDesc not implemented."); +} + tsl::Status DnnSupport::GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType bias_type, diff --git a/tensorflow/compiler/xla/stream_executor/dnn.h b/tensorflow/compiler/xla/stream_executor/dnn.h index ea0a94989e51e7..c34dd6d8956587 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.h +++ b/tensorflow/compiler/xla/stream_executor/dnn.h @@ -965,6 +965,13 @@ class OpRunner { virtual tsl::Status operator()(Stream*, ProfileResult*, DeviceMemoryBase scratch_memory, Args... args) const = 0; + + // Launch the operation with a variable number of operands. + virtual tsl::Status operator()(Stream*, ProfileResult*, + DeviceMemoryBase scratch_memory, + std::vector) const { + return tsl::errors::Unimplemented("operator() not implemented."); + }; }; using ConvSignature = void(DeviceMemoryBase /* input_data */, @@ -972,6 +979,8 @@ using ConvSignature = void(DeviceMemoryBase /* input_data */, DeviceMemoryBase /* output_data */); using ConvRunner = OpRunner; +using GraphConvRunner = OpRunner; + using FusedConvSignature = void(DeviceMemoryBase /* input_data */, DeviceMemoryBase /* filter_data */, DeviceMemoryBase /* side_input_data */, @@ -1626,6 +1635,30 @@ class DnnSupport { const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor); + virtual tsl::Status GetGraphConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + bool use_fallback, ScratchAllocator* scratch_allocator, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, + string serialized_graph); + + virtual tsl::StatusOr> + GraphConvolveRunnerFromDesc( + Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, + dnn::ConvolutionKind kind, dnn::DataType element_type, + dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor, + const dnn::FilterDescriptor& filter_descriptor, + const dnn::BatchDescriptor& output_descriptor, + const dnn::ConvolutionDescriptor& convolution_descriptor, + string serialized_graph); + virtual tsl::Status GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType bias_type, diff --git a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h index f51ffd59f8dc92..11ef2b3497793f 100644 --- a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h +++ b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h @@ -158,6 +158,32 @@ struct ConvOp { } }; +// Implementation of the 
concept required by LazyOpRunner, for +// GraphConvolveRunner. +struct GraphConvOp { + using Signature = ConvSignature; + + struct Config { + ConvolutionKind kind; + DataType input_type, output_type; + const BatchDescriptor& input_descriptor; + const FilterDescriptor& filter_descriptor; + const BatchDescriptor& output_descriptor; + const ConvolutionDescriptor& convolution_descriptor; + string serialized_graph; + }; + + static tsl::StatusOr>> + RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, + Stream* stream) { + return stream->GraphConvolveRunnerFromDesc( + desc, config.kind, config.input_type, config.output_type, + config.input_descriptor, config.filter_descriptor, + config.output_descriptor, config.convolution_descriptor, + config.serialized_graph); + } +}; + // Implementation of the concept required by LazyOpRunner, for LazyConvRunner. struct FusedConvOp { using Signature = FusedConvSignature; diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h index 2d156fb6462b74..b9811f0a467153 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.h +++ b/tensorflow/compiler/xla/stream_executor/stream.h @@ -435,6 +435,25 @@ class Stream { filter_descriptor, output_descriptor, convolution_descriptor); } + tsl::StatusOr> + GraphConvolveRunnerFromDesc( + const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, + dnn::DataType element_type, dnn::DataType output_type, + const dnn::BatchDescriptor &input_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + const dnn::BatchDescriptor &output_descriptor, + const dnn::ConvolutionDescriptor &convolution_descriptor, + string serialized_graph) { + dnn::DnnSupport *dnn_support = parent_->AsDnn(); + if (!dnn_support) { + return tsl::errors::Unimplemented("DNN library is not found."); + } + return dnn_support->GraphConvolveRunnerFromDesc( + this, algorithm_desc, kind, element_type, output_type, input_descriptor, + filter_descriptor, output_descriptor, convolution_descriptor, + serialized_graph); + } + tsl::StatusOr> FusedConvolveRunnerFromDesc( const dnn::AlgorithmDesc &algorithm_desc, dnn::ConvolutionKind kind, diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc index a329f02cb271a3..4bc3ae261bd1f3 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc @@ -286,6 +286,28 @@ tsl::Status StreamExecutor::GetConvolveRunners( scratch_allocator, numeric_options, out_exec_plans); } +tsl::Status StreamExecutor::GetGraphConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, + ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, + std::vector>* out_exec_plans, + string serialized_graph) { + dnn::DnnSupport* dnn_support = AsDnn(); + if (!dnn_support) { + return tsl::errors::Unimplemented("DNN library is not found."); + } + return dnn_support->GetGraphConvolveRunners( + use_cudnn_frontend, kind, input_type, output_type, stream, + input_descriptor, input_data, 
filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, use_fallback, + scratch_allocator, numeric_options, out_exec_plans, serialized_graph); +} + tsl::Status StreamExecutor::GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h index f32fe8e979710b..7a92d32aa3f99b 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h @@ -377,6 +377,20 @@ class StreamExecutor { const NumericOptions& numeric_options, std::vector>* out_exec_plans); + tsl::Status GetGraphConvolveRunners( + bool use_cudnn_frontend, dnn::ConvolutionKind kind, + dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + const dnn::FilterDescriptor& filter_descriptor, + DeviceMemoryBase filter_data, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemoryBase output_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + bool use_fallback, ScratchAllocator* scratch_allocator, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, + string serialized_graph); + tsl::Status GetFusedConvolveRunners( bool use_cudnn_frontend, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType bias_type, diff --git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index 983417903f5c7e..e09c4d9a2d50b6 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -1276,6 +1276,13 @@ tsl::StatusOr LhloDialectEmitter::EmitDnnConvolution( TF_RETURN_IF_ERROR(set_activation(cnn_fused_side_input)); return set_common_conv_attributes(cnn_fused_side_input); } + case xla::gpu::CudnnConvKind::kForwardGraph: { + TF_ASSIGN_OR_RETURN( + auto cnn_graph, + CreateOpWithoutAttrs(custom_call)); + cnn_graph.setSerializedGraph(backend_config.serialized_graph()); + return set_common_conv_attributes(cnn_graph); + } } } diff --git a/tensorflow/core/kernels/conv_ops_gpu.cc b/tensorflow/core/kernels/conv_ops_gpu.cc index 774efdd60f0c5e..da06511b33cd6f 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu.cc @@ -260,6 +260,7 @@ StatusOr> AutotuneUnfusedConv( switch (kind) { case se::dnn::ConvolutionKind::FORWARD: case se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION: + case se::dnn::ConvolutionKind::FORWARD_GRAPH: output_ptr = se::DeviceMemory( WrapRedzoneBestEffort(&rz_allocator, output_ptr)); break; diff --git a/tensorflow/tsl/protobuf/dnn.proto b/tensorflow/tsl/protobuf/dnn.proto index 65653da343da69..b349115292e43a 100644 --- a/tensorflow/tsl/protobuf/dnn.proto +++ b/tensorflow/tsl/protobuf/dnn.proto @@ -102,6 +102,7 @@ enum ConvolutionKind { BACKWARD_FILTER = 2; BACKWARD_DATA = 3; FORWARD_BIAS_ACTIVATION = 4; + FORWARD_GRAPH = 5; } // Generic tensor representation. From caade6453519ad2531ebcf8f206e40187a1687ca Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Wed, 19 Jul 2023 23:39:16 +0000 Subject: [PATCH 049/410] Support for FP8 convolutions in XLA. 
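At the center of the FP8 convolution support is the serialized graph string carried by the convolution custom call, e.g. "conv[f32]->scale[f32]->scale[f8e4m3fn]->", which lists the convolution and the pointwise ops fused behind it together with their output types. The sketch below shows how such a string splits into (op, output type) pairs; it is a simplified stand-in for the real parser in GetGenericCudnnOperationGraph, not a copy of it.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Splits "conv[f32]->scale[f32]->scale[f8e4m3fn]->" into {op, type} pairs.
std::vector<std::pair<std::string, std::string>> ParseSerializedGraph(
    const std::string& graph) {
  std::vector<std::pair<std::string, std::string>> ops;
  std::string::size_type pos = 0;
  while (pos < graph.size()) {
    std::string::size_type lbracket = graph.find('[', pos);
    std::string::size_type rbracket = graph.find(']', pos);
    if (lbracket == std::string::npos || rbracket == std::string::npos) break;
    ops.emplace_back(graph.substr(pos, lbracket - pos),
                     graph.substr(lbracket + 1, rbracket - lbracket - 1));
    pos = rbracket + 3;  // Skip past "]->".
  }
  return ops;
}

int main() {
  for (const auto& [op, type] :
       ParseSerializedGraph("conv[f32]->scale[f32]->scale[f8e4m3fn]->")) {
    std::cout << op << " : " << type << "\n";
  }
  // Prints:
  //   conv : f32
  //   scale : f32
  //   scale : f8e4m3fn
}

The real parser additionally maps each op name to a cuDNN pointwise mode and records whether it consumes a scalar or tensor side operand: scale and invscale take a broadcast scalar, add takes a tensor, relu takes none.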
--- tensorflow/compiler/xla/service/gpu/BUILD | 4 +- .../xla/service/gpu/backend_configs.proto | 2 +- .../xla/service/gpu/buffer_comparator.cc | 134 +++++----- .../xla/service/gpu/buffer_comparator_test.cc | 70 +++++ .../xla/service/gpu/conv_algorithm_picker.cc | 13 +- .../service/gpu/cudnn_fused_conv_rewriter.cc | 180 +++++++------ .../gpu/cudnn_fused_conv_rewriter_test.cc | 196 ++++++++++++-- .../xla/service/gpu/gpu_conv_runner.h | 9 +- .../xla/stream_executor/cuda/cuda_dnn.cc | 241 ++++++++---------- .../xla/stream_executor/cuda/cuda_dnn.h | 13 +- .../compiler/xla/stream_executor/dnn.cc | 14 +- tensorflow/compiler/xla/stream_executor/dnn.h | 25 +- .../xla/stream_executor/lazy_op_runner.h | 4 +- .../stream_executor/stream_executor_pimpl.cc | 20 +- .../stream_executor/stream_executor_pimpl.h | 13 +- 15 files changed, 564 insertions(+), 374 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7d84276d40126c..2c2ec570cea653 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3064,7 +3064,9 @@ cc_library( "//tensorflow/compiler/xla/stream_executor", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:statusor", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) cc_library( diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 1236d07be4cb6d..3663e51639e680 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -60,7 +60,7 @@ message CudnnConvBackendConfig { // Serialization of the graph described by the convolution and adjacent // pointwise ops. - optional string serialized_graph = 8; + optional string serialized_graph = 9; } // Backend config for the GEMM operation running through cuBLAS. diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index bc36fcf8eeb19b..f30d86add5ac4f 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -44,26 +44,27 @@ static constexpr double kTolerance = 0.1f; // NaN's are considered equal, and for half's we clamp all numbers to largest // and smallest numbers representable to avoid miscomparisons due to overflows. - +// // The PTX below is compiled from the following CUDA code: - +// // #include // #include // #include - +// // namespace { - +// // __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input) // { // // All fp16 infinities are treated as 65505 or -65505, in order to avoid // // differences due to overflows. // return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); // } - +// // } // end anonymous namespace - +// // extern "C" { // avoid name mangling - +// +// // __global__ void __xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t *buffer_a, // __nv_fp8_storage_t *buffer_b, // float rel_error_threshold, @@ -82,14 +83,14 @@ static constexpr double kTolerance = 0.1f; // elem_b = __xla_buffer_comparator_canonicalize(elem_b); // if (isnan(elem_a) && isnan(elem_b)) // return; - +// // float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + // 1); - +// // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - +// // __global__ void __xla_fp8_e5m2_comparison(__nv_fp8_storage_t *buffer_a, // __nv_fp8_storage_t *buffer_b, // float rel_error_threshold, @@ -108,121 +109,114 @@ static constexpr double kTolerance = 0.1f; // elem_b = __xla_buffer_comparator_canonicalize(elem_b); // if (isnan(elem_a) && isnan(elem_b)) // return; - +// // float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + // 1); - +// // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - -// __global__ void __xla_fp16_comparison(__half *buffer_a, __half *buffer_b, +// +// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; +// if (idx >= buffer_length) return; // float elem_a = __half2float(buffer_a[idx]); // float elem_b = __half2float(buffer_b[idx]); // elem_a = __xla_buffer_comparator_canonicalize(elem_a); // elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) -// return; - -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); - +// if (isnan(elem_a) && isnan(elem_b)) return; +// +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - -// __global__ void __xla_fp32_comparison(float *buffer_a, float *buffer_b, +// +// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; +// if (idx >= buffer_length) return; // float elem_a = buffer_a[idx]; // float elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) -// return; +// if (isnan(elem_a) && isnan(elem_b)) return; // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) // return; - -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) +// +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - -// __global__ void __xla_fp64_comparison(double *buffer_a, double *buffer_b, +// +// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; - +// if 
(idx >= buffer_length) return; +// // double elem_a = buffer_a[idx]; // double elem_b = buffer_b[idx]; -// if (isnan(elem_a) && isnan(elem_b)) -// return; +// if (isnan(elem_a) && isnan(elem_b)) return; // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b)) // return; -// double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) +// double rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - -// __global__ void __xla_bf16_comparison(__nv_bfloat16 *buffer_a, -// __nv_bfloat16 *buffer_b, +// +// __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a, +// __nv_bfloat16* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; +// if (idx >= buffer_length) return; // float elem_a = __bfloat162float(buffer_a[idx]); // float elem_b = __bfloat162float(buffer_b[idx]); // elem_a = __xla_buffer_comparator_canonicalize(elem_a); // elem_b = __xla_buffer_comparator_canonicalize(elem_b); -// if (isnan(elem_a) && isnan(elem_b)) -// return; - -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); - +// if (isnan(elem_a) && isnan(elem_b)) return; +// +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } - +// // // TODO(b/191520348): The comparison below requires exact equality. -// __global__ void __xla_int8_comparison(int8_t *buffer_a, int8_t *buffer_b, +// __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; +// if (idx >= buffer_length) return; // float a = buffer_a[idx]; // float b = buffer_b[idx]; // float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1); // if (rel_error > rel_error_threshold || isnan(rel_error)) -// atomicAdd(mismatch_count, 1); +// atomicAdd(mismatch_count, 1); // } - -// __global__ void __xla_int32_comparison(int *buffer_a, int *buffer_b, +// +// __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, -// int *mismatch_count) { +// int* mismatch_count) { // int idx = threadIdx.x + blockIdx.x * blockDim.x; -// if (idx >= buffer_length) -// return; +// if (idx >= buffer_length) return; // float elem_a = static_cast(buffer_a[idx]); // float elem_b = static_cast(buffer_b[idx]); -// float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + -// 1); if (rel_error > rel_error_threshold || isnan(rel_error)) +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } // } // end extern declaration @@ -1190,10 +1184,10 @@ StatusOr BufferComparator::CompareEqual( switch (shape_.element_type()) { case xla::F8E4M3FN: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_fp8_e4m3fn_comparison"); + stream, current, expected, shape_, config_, 
"__xla_fp8_e4m3fn_comparison"); case xla::F8E5M2: return CompareEqualParameterized( - stream, lhs, rhs, shape_, config_, "__xla_fp8_e5m2_comparison"); + stream, current, expected, shape_, config_, "__xla_fp8_e5m2_comparison"); case xla::F16: return CompareEqualParameterized( stream, current, expected, shape_, config_, "__xla_fp16_comparison"); diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc index a5d2886fe5342b..04550d6bb122aa 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -162,6 +162,30 @@ TEST_F(BufferComparatorTest, TestInfs) { EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); + + EXPECT_TRUE( + CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); + + EXPECT_FALSE( + CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); } TEST_F(BufferComparatorTest, TestNumbers) { @@ -189,6 +213,18 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({90}, {100})); EXPECT_TRUE(CompareEqualFloatBuffers({100}, {90})); EXPECT_FALSE(CompareEqualFloatBuffers({-128}, {127})); + + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({11}, {12})); + EXPECT_TRUE(CompareEqualFloatBuffers({12}, {11})); } TEST_F(BufferComparatorTest, TestMultiple) { @@ -259,6 +295,40 @@ TEST_F(BufferComparatorTest, TestMultiple) { rhs[i] = 0; } } + + { + EXPECT_TRUE(CompareEqualFloatBuffers( + {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } + + { + EXPECT_TRUE(CompareEqualFloatBuffers( + {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector 
lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } } TEST_F(BufferComparatorTest, BF16) { diff --git a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc index c3b21a5e216e5c..e7e3f6fcfba88c 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_algorithm_picker.cc @@ -160,15 +160,10 @@ StatusOr> GetAlgorithms( // This path is cuDNN-only, where the DeviceMemoryBase arguments and the // allocator are unused; so, they're all provided as nullptr. TF_RETURN_IF_ERROR(stream_exec->GetGraphConvolveRunners( - use_cudnn_frontend, kind, input_type, output_type, stream, - config.input_descriptor, - /* input_data = */ DeviceMemoryBase(nullptr), - config.filter_descriptor, - /* filter_data = */ DeviceMemoryBase(nullptr), - config.output_descriptor, - /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc, - use_fallback, nullptr, se::NumericOptions{deterministic_ops}, - &runners, config.serialized_graph)); + kind, input_type, output_type, stream, config.input_descriptor, + config.filter_descriptor, config.output_descriptor, config.conv_desc, + use_fallback, numeric_options, &runners, + config.serialized_graph)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 8dde0aef08b584..3c1f337c0f1d20 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -301,8 +301,12 @@ bool IsF8Type(const HloInstruction* instr) { return primitive_util::IsF8Type(instr->shape().element_type()); } -// Format of the serialized graph is -// "conv[output type]->op name[output type]->op name[output type]->...". +// The format of the serialized graph describing a linear sequence of ops fused +// into the cuDNN convolution Custom Call is +// "conv[output_type]->op_name[output_type]->op_name[output_type]->..." with the +// convolution assumed to be the first op in the graph. Currently, +// multiplication and division by a broadcast scalar, addition of a matrix bias +// and the application of a ReLU activation are supported. class GraphString { public: GraphString() : size_(0) {} @@ -331,7 +335,7 @@ class GraphString { // Recursively captures and serializes the graph of pointwise operations // operating on the convolution. -bool CaptureConvGraphRecursive(HloInstruction* instr, +void CaptureConvGraphRecursive(HloInstruction* instr, std::vector& operands, GraphString& graph_string, absl::flat_hash_set& visited_instrs, @@ -341,7 +345,7 @@ bool CaptureConvGraphRecursive(HloInstruction* instr, const int max_pattern_level = 1; // Avoid visiting the same instruction more than once. if (!visited_instrs.emplace(instr->unique_id()).second) { - return false; + return; } // When the function was called from outside or after a successful match, set // the final instruction to the current instruction. 
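// Illustrative sketch of the serialized-graph contract described above: the
// fused op sequence is encoded as "op_name[output_type]->" tokens, with the
// convolution first. The standalone helper below is hypothetical; in the pass
// itself the GraphString class assembles this string via AppendOp and
// ChangeDataType.
#include <string>
#include <utility>
#include <vector>

std::string SerializeOpSequence(
    const std::vector<std::pair<std::string, std::string>>& ops) {
  // Each entry is {op_name, output_type}, e.g. {"conv", "f32"}.
  std::string graph;
  for (const auto& [op_name, output_type] : ops) {
    graph += op_name + "[" + output_type + "]->";
  }
  return graph;
}

// SerializeOpSequence({{"conv", "f32"}, {"scale", "f32"}, {"scale", "f8e4m3fn"}})
// yields "conv[f32]->scale[f32]->scale[f8e4m3fn]->", the form the tests below
// check in the "serialized_graph" backend-config attribute.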
@@ -349,77 +353,94 @@ bool CaptureConvGraphRecursive(HloInstruction* instr, final_instr = instr; } - HloInstruction *op, *operand; - for (HloInstruction* user : instr->users()) { + if (instr->user_count() == 1) { + HloInstruction *op, *operand, *user = instr->users()[0]; if (pattern_level == 0) { // Add if (Match(user, m::AddAnyOrder(&op, m::Op(), m::Op(&operand)))) { graph_string.AppendOp("add", op->shape().element_type()); operands.push_back(operand); - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; } // Scale if (Match(user, m::MultiplyAnyOrder(&op, m::Op(), - m::Broadcast(m::Op(&operand))))) { + m::Broadcast(m::Op(&operand)))) && + ShapeUtil::IsScalar(operand->shape())) { graph_string.AppendOp("scale", op->shape().element_type()); operands.push_back(operand); - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; } // Inverse Scale - if (Match(user, m::Divide(&op, m::Op(), m::Broadcast(m::Op(&operand))))) { + if (Match(user, m::Divide(&op, m::Op(), m::Broadcast(m::Op(&operand)))) && + ShapeUtil::IsScalar(operand->shape())) { graph_string.AppendOp("invscale", op->shape().element_type()); operands.push_back(operand); - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; } // ReLU if (Match(user, m::MaximumAnyOrder(&op, m::Op(), m::Broadcast(m::ConstantScalar(0))))) { graph_string.AppendOp("relu", op->shape().element_type()); - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; } } if (pattern_level == 1) { - // Convert with clamp + // Convert with clamp to FP8 types + HloInstruction *clamp_lower, *clamp_upper; if (Match(user, - m::Convert(&op, - m::Clamp(m::Broadcast(m::ConstantScalar()), m::Op(), - m::Broadcast(m::ConstantScalar()))))) { - graph_string.ChangeDataType(op->shape().element_type()); - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); + m::Convert( + &op, + m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), + m::Op(), + m::Broadcast(m::ConstantScalar(&clamp_upper)))))) { + if ((op->shape().element_type() == F8E4M3FN && + clamp_lower->literal().IsAllFloat(static_cast( + std::numeric_limits::lowest())) && + clamp_upper->literal().IsAllFloat(static_cast( + std::numeric_limits::max()))) || + (op->shape().element_type() == F8E5M2 && + clamp_lower->literal().IsAllFloat(static_cast( + std::numeric_limits::lowest())) && + clamp_upper->literal().IsAllFloat(static_cast( + std::numeric_limits::max())))) { + graph_string.ChangeDataType(op->shape().element_type()); + CaptureConvGraphRecursive(user, operands, graph_string, + visited_instrs, final_instr, 0); + return; + } } } // If none of the matches was successful and the pattern level is below the // maximum level, attempt to match at higher level. 
if (pattern_level < max_pattern_level) { - return CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, - pattern_level + 1); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, pattern_level + 1); + return; } } - // The first entry in the serialized graph is the convolution. The size is - // greater than one if at least one match was succesful. - if (graph_string.Size() > 1) { - return true; - } else { - return false; - } + return; } -// Captures the graph of pointwise operations operating on the convolution. -bool CaptureConvGraph(HloInstruction* instr, - std::vector& operands, - GraphString& graph_string, HloInstruction*& final_instr, - HloInstruction* x_scale, HloInstruction* w_scale, - bool x_mult_scale, bool w_mult_scale) { - absl::flat_hash_set visited_instrs; +// Captures in a GraphString the subgraph of pointwise operations operating on +// the convolution that will be fused into the cuDNN convolution Custom Call. +std::tuple, GraphString, HloInstruction*> +CaptureConvGraph(HloInstruction* instr, HloInstruction* x_scale, + HloInstruction* w_scale, bool x_mult_scale, + bool w_mult_scale) { + std::vector operands; + GraphString graph_string; + graph_string.AppendOp("conv", instr->shape().element_type()); // Shift the scaling of the inputs to the output of the convolution. @@ -442,8 +463,13 @@ bool CaptureConvGraph(HloInstruction* instr, instr->shape().element_type()); } } - return CaptureConvGraphRecursive(instr, operands, graph_string, - visited_instrs, final_instr); + + absl::flat_hash_set visited_instrs; + HloInstruction* final_instr; + CaptureConvGraphRecursive(instr, operands, graph_string, visited_instrs, + final_instr); + + return std::make_tuple(operands, graph_string, final_instr); } // Matches convolutions operating on FP8 inputs and filters and rewrites into a @@ -451,12 +477,12 @@ bool CaptureConvGraph(HloInstruction* instr, // following steps are elided and rewritten into a ForwardGraph Custom Call: // // 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32. -// 2. Unscale the filter and input by multiplying them by the corresponding -// input scale. +// 2. Unscale the filter and input by multiplying or dividing by scalars. // 3. Evaluate the convolution based on the scaled filter and input. -// 4. Optionally add a matrix bias to the result and apply a ReLU activation. -// 5. Scale the output by dividing the output by the output scale. -// 6. Cast the output back to FP8. +// 4. Apply a series of elementwise transformations, where a transformation can +// be adding a matrix bias, applying a ReLU activation, or +// multiplying or dividing by a broadcast scalar. +// 5. Cast the output back to FP8. StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { bool changed = false; @@ -467,11 +493,9 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { if (!cc.IsAtLeast(se::CudaComputeCapability::HOPPER)) { return false; } - HloInstruction *convolution, *gte, *input, *filter, *final_instr, + HloInstruction *convolution, *gte, *input, *filter, *x_scale = nullptr, *w_scale = nullptr, *x_scale_op = nullptr, *w_scale_op = nullptr; - std::vector operands; - GraphString graph_string; // TODO(philipphack): Consider allowing ops between dequantization and // convolution. 
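// A scalar sketch of the clamp-range test performed by the convert-with-clamp
// match above: the cast back to FP8 is only captured when the clamp bounds
// equal the full finite range of the target type (F8E4M3FN spans [-448, 448],
// F8E5M2 spans [-57344, 57344]). The helper name is illustrative only.
bool ClampCoversF8Range(bool is_e4m3fn, float clamp_lower, float clamp_upper) {
  const float max_finite = is_e4m3fn ? 448.0f : 57344.0f;
  return clamp_lower == -max_finite && clamp_upper == max_finite;
}

// The FP8 test modules below clamp to exactly these bounds (for example, the
// e5m2 cases use the constants -57344 and 57344), so the trailing convert is
// folded into the serialized graph as its final output type.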
@@ -498,32 +522,36 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { m::Convert(m::Op(&filter).WithPredicate(IsF8Type)), m::Broadcast(m::Op(&w_scale))))), 0); - if (Match(instr, pattern) && - CaptureConvGraph( - const_cast(instr), operands, graph_string, - final_instr, x_scale, w_scale, - x_scale_op ? x_scale_op->opcode() == HloOpcode::kMultiply : false, - w_scale_op ? w_scale_op->opcode() == HloOpcode::kMultiply - : false)) { - TF_ASSIGN_OR_RETURN( - auto config, convolution->backend_config()); - config.set_serialized_graph(graph_string.Graph()); - operands.insert(operands.begin(), input); - operands.insert(operands.begin() + 1, filter); - - Shape new_shape = ShapeUtil::MakeTupleShape( - {ShapeUtil::ChangeElementType( - ShapeUtil::GetTupleElementShape(convolution->shape(), 0), - final_instr->shape().element_type()), - ShapeUtil::GetTupleElementShape(convolution->shape(), 1)}); - HloInstruction* new_convolution = comp->AddInstruction( - convolution->CloneWithNewOperands(new_shape, operands)); - new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget); - TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, - MakeGetTupleElementHlo(new_convolution, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); - changed = true; + if (Match(instr, pattern)) { + std::vector operands; + GraphString graph_string; + HloInstruction* final_instr; + std::tie(operands, graph_string, final_instr) = CaptureConvGraph( + const_cast(instr), x_scale, w_scale, + x_scale_op ? x_scale_op->opcode() == HloOpcode::kMultiply : false, + w_scale_op ? w_scale_op->opcode() == HloOpcode::kMultiply : false); + if (graph_string.Size() > 1) { + TF_ASSIGN_OR_RETURN( + auto config, convolution->backend_config()); + config.set_serialized_graph(graph_string.Graph()); + operands.insert(operands.begin(), input); + operands.insert(operands.begin() + 1, filter); + + Shape new_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::ChangeElementType( + ShapeUtil::GetTupleElementShape(convolution->shape(), 0), + final_instr->shape().element_type()), + ShapeUtil::GetTupleElementShape(convolution->shape(), 1)}); + HloInstruction* new_convolution = comp->AddInstruction( + convolution->CloneWithNewOperands(new_shape, operands)); + new_convolution->set_custom_call_target( + kCudnnConvForwardGraphCallTarget); + TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_convolution, 0)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); + changed = true; + } } } #endif diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 8ba4995c3b110a..401bbe13572447 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -141,24 +141,42 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } void TestF8(absl::string_view pre_hlo_string, - absl::string_view post_hlo_string) { - if (!GetCudaComputeCapability().IsAtLeast( + absl::string_view custom_call_string, + absl::string_view serialized_graph_string) { + if (GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() << "FP8 convolutions require Hopper or newer architecture."; + // On Hopper or newer architectures, test numerical correctness 
and verify + // the HLO of the Custom Call with operand and return layouts and the + // serialized graph based on the full compiler pipeline. + std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string); + EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert"))); + EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv")); + EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.1})) + << pre_hlo_string; + + StatusOr filecheck_result = + RunFileCheck(optimized_hlo_string, custom_call_string); + ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); + EXPECT_TRUE(*filecheck_result); + + filecheck_result = + RunFileCheck(optimized_hlo_string, serialized_graph_string); + ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); + EXPECT_TRUE(*filecheck_result); + } else { + // On older architectures, disregard layout information and only verify + // the serialized graph based on the GpuConvRewriter and + // CudnnFusedConvRewriter passes. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(pre_hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloPass(GpuConvRewriter(), module.get())); + RunAndFilecheckHloRewrite( + module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), + CudnnFusedConvRewriter( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}), + serialized_graph_string); } - std::string alpha_conv_scalar, alpha_side_input_scalar; - std::string elementwise_type; - - std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string); - EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert"))); - EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv")); - EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.1})) - << pre_hlo_string; - - StatusOr filecheck_result = - RunFileCheck(optimized_hlo_string, post_hlo_string); - ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status(); - EXPECT_TRUE(*filecheck_result); } }; @@ -622,7 +640,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01})); } -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledYF8) { +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { TestF8( // pre_hlo R"( @@ -645,9 +663,12 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledYF8) { ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) })", - // post_hlo + // custom_call R"( // CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( // CHECK: "serialized_graph":"conv[f32]-\u003escale[f8e4m3fn]-\u003e" )"); } @@ -681,13 +702,133 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8) { ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) })", - // post_hlo + // custom_call R"( // CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( // CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e4m3fn]-\u003e" )"); } +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledInputE5M2F8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e5m2[1,128,6,6] parameter(0) + 
filter = f8e4m3fn[3,3,128,16] parameter(1) + input_scale = f32[] parameter(2) + input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} + filter_scale = f32[] parameter(3) + filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} + input_f32 = f32[1,128,6,6] convert(input) + input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) + filter_f32 = f32[3,3,128,16] convert(filter) + filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) + z_scale = f32[] parameter(4) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) + c1 = f32[] constant(-57344.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(57344.) + c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) + + })", + // custom_call + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledFilterE5M2F8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e4m3fn[1,128,6,6] parameter(0) + filter = f8e5m2[3,3,128,16] parameter(1) + input_scale = f32[] parameter(2) + input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} + filter_scale = f32[] parameter(3) + filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} + input_f32 = f32[1,128,6,6] convert(input) + input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) + filter_f32 = f32[3,3,128,16] convert(filter) + filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) + z_scale = f32[] parameter(4) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) + c1 = f32[] constant(-57344.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(57344.) 
+ c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) + + })", + // custom_call + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" + )"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledInputE5M2FilterE5M2F8) { + TestF8( + // pre_hlo + R"( + HloModule Test + + ENTRY Test { + input = f8e5m2[1,128,6,6] parameter(0) + filter = f8e5m2[3,3,128,16] parameter(1) + input_scale = f32[] parameter(2) + input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} + filter_scale = f32[] parameter(3) + filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} + input_f32 = f32[1,128,6,6] convert(input) + input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) + filter_f32 = f32[3,3,128,16] convert(filter) + filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) + z_scale = f32[] parameter(4) + z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} + conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) + c1 = f32[] constant(-57344.) + c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} + c2 = f32[] constant(57344.) + c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} + conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) + ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) + + })", + // custom_call + R"( +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" + )"); +} + TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { TestF8( // pre_hlo @@ -699,7 +840,7 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { filter = f8e4m3fn[3,3,128,16] parameter(1) input_scale = f32[] parameter(2) input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} - filter_scale = f32[] parameter(3) + filter_scale = f32[] parameter(3) filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} input_f32 = f32[1,128,6,6] convert(input) input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) @@ -719,9 +860,12 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) })", - // post_hlo + // custom_call R"( // CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]], [[OPERAND4:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( // CHECK: 
"serialized_graph":"conv[f32]-\u003escale[f32]-\u003eadd[f32]-\u003escale[f8e4m3fn]-\u003e" )"); } @@ -749,9 +893,12 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledF8) { ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) })", - // post_hlo + // custom_call R"( // CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( // CHECK: "serialized_graph":"conv[f32]-\u003einvscale[f8e4m3fn]-\u003e" )"); } @@ -782,9 +929,12 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluActivationF8) { ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(relu_a_clamped) })", - // post_hlo + // custom_call R"( // CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )", + // serialized_graph + R"( // CHECK: "serialized_graph":"conv[f32]-\u003erelu[f32]-\u003escale[f8e4m3fn]-\u003e" )"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index dd490565ea38aa..6b07091cdf43f4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -70,6 +70,10 @@ struct GpuConvConfig { Shape filter_shape; Shape output_shape; std::optional fusion; + + // String serialization of the subgraph of adjacent ops to be fused into the + // cuDNN convolution Custom Call. Currently used for FP8 convolutions only. + // Additional information is provided in gpu_fused_conv_rewriter.cc. std::string serialized_graph; }; @@ -85,7 +89,8 @@ struct GpuConvParams { se::DeviceMemoryBase filter_buf; se::DeviceMemoryBase output_buf; - // Buffers for operands of pointwise ops. + // Buffers for operands of pointwise ops to be fused into the cuDNN + // convolution Custom Call. std::vector operand_bufs; std::optional fusion; @@ -95,7 +100,7 @@ struct GpuConvParams { // convolution is fused (and has extra arguments) or unfused, which doesn't // naturally play well with the typed APIs provided by StreamExecutor; rather // than rewriting everything here, just propagate the dynamic typing to one more -// place by having either a FusedConvRunner or a ConvRunner. +// place by having a ConvRunner, FusedConvRunner or GraphConvRunner. 
class GenericConvRunner { public: GenericConvRunner() = default; diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 3f426604d2aba0..717eecb72f199c 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -3540,6 +3540,19 @@ tsl::StatusOr CreateCudnnTensor( RETURN_MSG_IF_CUDNN_ERROR(tensor); return tensor; } + +tsl::StatusOr CreateCudnnTensor( + const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, + bool is_virtual = false) { + auto tensor = cudnn_frontend::TensorBuilder() + .cloneFrom(original, uid) + .setAlignment(32) + .setDataType(ToCudnnDataType(dtype)) + .setVirtual(is_virtual) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(tensor); + return tensor; +} #else tsl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, @@ -3569,6 +3582,13 @@ tsl::StatusOr CreateCudnnTensor( RETURN_MSG_IF_CUDNN_ERROR(tensor); return tensor; } + +tsl::StatusOr CreateCudnnTensor( + cudnn_frontend::Tensor original, int64_t uid, dnn::DataType dtype, + bool is_virtual = false) { + return tsl::errors : Internal("Not implemented."); +} + #endif #if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) @@ -4191,6 +4211,12 @@ OpNameStringToInputKindAndMode(string opstring) { return tsl::errors::Internal("Unknown op."); } +// TODO(philipphack): Consider merging with GetCudnnOperationGraph and +// GetCudnnFusedOperationGraph. + +// Returns a generic cuDNN OperationGraph for ForwardGraph +// convolutions with dynamically identified fused ops and the +// associated set of UIDs of non-virtual cuDNN tensors. tsl::StatusOr, std::vector>> GetGenericCudnnOperationGraph( @@ -4201,7 +4227,6 @@ GetGenericCudnnOperationGraph( const dnn::ConvolutionDescriptor& convolution_descriptor, CudnnHandle& cudnn, string serialized_graph = "") { PreloadCudnnSubLibsHelper(kind); - std::vector virtual_uids, non_virtual_uids; // Struct to describe the ops (convolution and pointwise) in the sequence // described by the graph. @@ -4211,8 +4236,10 @@ GetGenericCudnnOperationGraph( dnn::DataType output_type; }; - // Format of the serialized graph is - // "conv[output type]->op name[output type]->op name[output type]->...". + // The format of the serialized graph describing a linear sequence of ops + // fused into the cuDNN convolution Custom Call is + // "conv[output_type]->op_name[output_type]->op_name[output_type]->..." with + // the convolution assumed to be first op in the graph. auto deserialize_cudnn_graph = [&]() -> tsl::StatusOr> { std::vector op_sequence = {}; @@ -4228,10 +4255,18 @@ GetGenericCudnnOperationGraph( TF_ASSIGN_OR_RETURN(output_type, PrimitiveTypeStringToDnnType(data_type_string)); if (op_string == "conv") { + if (!op_sequence.empty()) { + return tsl::errors::Internal( + "The graph must not contain more than one convolution op."); + } mode = convolution_descriptor.convolution_not_crosscorr() ? 
CUDNN_CONVOLUTION : CUDNN_CROSS_CORRELATION; } else { + if (op_sequence.empty()) { + return tsl::errors::Internal( + "The first op in the graph must be a convolution."); + } TF_ASSIGN_OR_RETURN(std::tie(input_kind, mode), OpNameStringToInputKindAndMode(op_string)); } @@ -4251,6 +4286,25 @@ GetGenericCudnnOperationGraph( cudnnBackendDescriptorType_t conv_mode = GetCudnnConvolutionType(kind); std::vector ops = {}; + std::vector virtual_uids, non_virtual_uids; + + auto next_uid = [&non_virtual_uids, + &virtual_uids](bool is_virtual) -> int64_t { + int64_t next_uid = + std::max(non_virtual_uids.empty() + ? 0 + : *std::max_element(non_virtual_uids.begin(), + non_virtual_uids.end()), + virtual_uids.empty() ? 0 + : *std::max_element(virtual_uids.begin(), + virtual_uids.end())) + + 1; + if (is_virtual) { + return virtual_uids.emplace_back(next_uid); + } else { + return non_virtual_uids.emplace_back(next_uid); + } + }; // x tensor. int vector_size, vector_dim; @@ -4262,9 +4316,9 @@ GetGenericCudnnOperationGraph( dnn::DataLayout::kBatchDepthYX, vector_size, vector_dim); TF_ASSIGN_OR_RETURN(auto tensor_x, - CreateCudnnTensor(input_dims, input_strides, 'x', + CreateCudnnTensor(input_dims, input_strides, + next_uid(/*is_virtual=*/false), input_type, vector_size, vector_dim)); - non_virtual_uids.push_back('x'); // w tensor. std::tie(vector_size, vector_dim) = @@ -4281,10 +4335,10 @@ GetGenericCudnnOperationGraph( : CUDNN_TENSOR_REORDERING_NONE; TF_ASSIGN_OR_RETURN( auto tensor_w, - CreateCudnnTensor(filter_dims, filter_strides, 'w', input_type, - vector_size, vector_dim, + CreateCudnnTensor(filter_dims, filter_strides, + next_uid(/*is_virtual=*/false), input_type, vector_size, + vector_dim, /*is_virtual=*/false, tensor_ordering_type)); - non_virtual_uids.push_back('w'); // y tensor. std::tie(vector_size, vector_dim) = @@ -4296,10 +4350,10 @@ GetGenericCudnnOperationGraph( TF_ASSIGN_OR_RETURN( auto tensor_y, - CreateCudnnTensor(output_dims, output_strides, 'y', + CreateCudnnTensor(output_dims, output_strides, + next_uid(/*is_virtual=*/true), op_sequence[0].output_type, vector_size, vector_dim, /*is_virtual=*/true)); - virtual_uids.push_back('y'); auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(input_type)); CHECK_NE(convolution_descriptor.pad_alignment(), @@ -4346,51 +4400,30 @@ GetGenericCudnnOperationGraph( // Create cuDNN tensors for operands of binary ops (side inputs). 
if (op_descriptor.input_kind == InputKind::kScalar) { std::vector scale_dim(4, 1); - second_operand = - cudnn_frontend::TensorBuilder() - .setDim(4, scale_dim.data()) - .setStrides(4, scale_dim.data()) - .setId(non_virtual_uids.emplace_back( - std::min(non_virtual_uids.back(), virtual_uids.back()) - 1)) - .setAlignment(32) - .setDataType( - ToCudnnDataType(op_sequence[op_num - 1].output_type)) - .build(); + TF_ASSIGN_OR_RETURN( + second_operand, + CreateCudnnTensor(scale_dim, scale_dim, + next_uid(/*is_virtual=*/false), + op_sequence[op_num - 1].output_type, 1, -1)); VLOG(4) << "\nPointwise operand: " << second_operand->describe(); } else if (op_descriptor.input_kind == InputKind::kTensor) { - second_operand = - cudnn_frontend::TensorBuilder() - .cloneFrom(tensor_y, non_virtual_uids.emplace_back( - std::min(non_virtual_uids.back(), - virtual_uids.back()) - - 1)) - .setVirtual(false) - .setAlignment(32) - .setDataType( - ToCudnnDataType(op_sequence[op_num - 1].output_type)) - .build(); + TF_ASSIGN_OR_RETURN( + second_operand, + CreateCudnnTensor(tensor_y, next_uid(/*is_virtual=*/false), + op_sequence[op_num - 1].output_type, + /*is_virtual=*/false)); VLOG(4) << "\nPointwise operand: " << second_operand->describe(); } // Create the result tensor of the op. - cudnn_frontend::Tensor result = - cudnn_frontend::TensorBuilder() - .cloneFrom( - tensor_y, - std::min(non_virtual_uids.back(), virtual_uids.back()) - 1) - .setVirtual(op_num != op_sequence.size() - 1) - .setDataType(ToCudnnDataType(op_descriptor.output_type)) - .build(); + TF_ASSIGN_OR_RETURN( + cudnn_frontend::Tensor result, + CreateCudnnTensor( + tensor_y, + next_uid(/*is_virtual=*/op_num != op_sequence.size() - 1), + op_descriptor.output_type, op_num != op_sequence.size() - 1)); VLOG(4) << "\nPointwise result: " << result.describe(); - if (op_num == op_sequence.size() - 1) { - non_virtual_uids.emplace_back( - std::min(non_virtual_uids.back(), virtual_uids.back()) - 1); - } else { - virtual_uids.emplace_back( - std::min(non_virtual_uids.back(), virtual_uids.back()) - 1); - } - // Create the descriptor of the op. cudnn_frontend::PointWiseDesc desc = cudnn_frontend::PointWiseDescBuilder() @@ -6511,7 +6544,7 @@ class CudnnExecutionPlanRunner tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, DeviceMemoryBase scratch_memory, - std::vector inputs) const { + Args... 
inputs) const override { if (static_cast(parent_) != stream->parent()->implementation()) { return tsl::errors::Internal( @@ -6522,87 +6555,31 @@ class CudnnExecutionPlanRunner size_t workspace_size = plan_.getWorkspaceSize(); RETURN_MSG_IF_CUDNN_ERROR(plan_); + bool should_add_scalars = + !scalar_input_uids_.empty() && !scalar_input_values_.empty(); + RETURN_MSG_IF_CUDNN_ERROR(plan_); std::vector data_uids_vec = {data_uids_.cbegin(), data_uids_.cend()}; std::vector data_ptrs_vec; - for (DeviceMemoryBase input : inputs) { - data_ptrs_vec.push_back(input.opaque()); - } - - auto variantPack = - cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(scratch_memory.opaque()) - .setDataPointers(data_ptrs_vec.size(), data_ptrs_vec.data()) - .setUids(data_uids_vec.size(), data_uids_vec.data()) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(variantPack); - - VLOG(4) << "\nDo cudnn execution plan with plan tag: " << plan_.getTag() - << "\nWorkspace size in bytes: " << workspace_size - << "\nVariantPack: " << variantPack.describe(); - - const bool is_profiling = profile_result != nullptr; - std::unique_ptr timer; - if (is_profiling) { - timer.reset(new GpuTimer(parent_)); // NOLINT - // The start and stop of the timer should be as close to the Cudnn call as - // possible. It is still possible for other threads to issue workload on - // to this stream. So it could take multiple profiling measurements. - if (!timer->Init() || !timer->Start(AsGpuStream(stream))) { - return tsl::Status(absl::StatusCode::kInternal, - "Failed to start timer"); + // The operands of ForwardGraph convolutions are gathered dynamically. In + // this case, Args... is std::vector. + if constexpr (sizeof...(Args) == 1 && + std::is_same_v>, + std::vector>) { + for (DeviceMemoryBase input : std::get<0>(std::tie(inputs...))) { + data_ptrs_vec.push_back(input.opaque()); } - } - - cudnnStatus_t status = cudnnBackendExecute( - cudnn.handle(), plan_.get_raw_desc(), variantPack.get_raw_desc()); - RETURN_IF_CUDNN_ERROR(status); - - if (is_profiling) { - if (!timer->Stop(AsGpuStream(stream))) { - return tsl::Status(absl::StatusCode::kInternal, "Failed to stop timer"); + } else { + data_ptrs_vec = {inputs.opaque()...}; + // We use need_side_input to determine if the side input 'z' from + // {'x', 'w', 'z', 'b', 'y'} is needed for the conv--bias-act + // patterns. + if (sizeof...(Args) == 5 && !need_side_input_) { + data_uids_vec.erase(data_uids_vec.begin() + 2); + data_ptrs_vec.erase(data_ptrs_vec.begin() + 2); } - TF_ASSIGN_OR_RETURN(auto desc, ToAlgorithmDesc()); - profile_result->set_algorithm(desc); - profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds()); - profile_result->set_scratch_size(scratch_memory.size()); - - VLOG(4) << "cudnn op with plan " << plan_.getTag() - << ", workspace_size=" << workspace_size << " -> " - << CudnnStatusToString(status) << " in " - << timer->GetElapsedMilliseconds() << "ms"; - } - - return tsl::OkStatus(); - } - - tsl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... 
inputs) const override { - if (static_cast(parent_) != - stream->parent()->implementation()) { - return tsl::errors::Internal( - "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); - } - - auto cudnn = cudnn_->GetHandle(parent_, stream); - - size_t workspace_size = plan_.getWorkspaceSize(); - RETURN_MSG_IF_CUDNN_ERROR(plan_); - bool should_add_scalars = - !scalar_input_uids_.empty() && !scalar_input_values_.empty(); - CHECK(scalar_input_uids_.size() == scalar_input_values_.size()); - std::array data_ptrs = {inputs.opaque()...}; - - std::vector data_uids_vec(data_uids_.cbegin(), data_uids_.cend()); - std::vector data_ptrs_vec(data_ptrs.cbegin(), data_ptrs.cend()); - // We use need_side_input to determine if the side input 'z' from - // {'x', 'w', 'z', 'b', 'y'} is needed for the conv--bias-act patterns. - if (sizeof...(Args) == 5 && !need_side_input_) { - data_uids_vec.erase(data_uids_vec.begin() + 2); - data_ptrs_vec.erase(data_ptrs_vec.begin() + 2); } if (!data_ptrs_vec.empty() && data_ptrs_vec.back() == nullptr && @@ -6940,32 +6917,22 @@ tsl::Status CudnnSupport::GetConvolveRunners( } tsl::Status CudnnSupport::GetGraphConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, Stream* stream, const dnn::BatchDescriptor& input_descriptor, - DeviceMemoryBase /*input_data*/, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase /*filter_data*/, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase /*output_data*/, const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, - ScratchAllocator* /*scratch_allocator*/, const NumericOptions& numeric_options, std::vector>* out_exec_plans, string serialized_graph) { - - if (!use_cudnn_frontend) { - return tsl::errors::Internal( - "cuDNN graph execution requires the use of the cuDNN frontend."); - } - auto cudnn = cudnn_->GetHandle(parent_, stream); TF_ASSIGN_OR_RETURN( auto op_graph_and_uids, GetGenericCudnnOperationGraph( kind, input_type, input_descriptor, filter_descriptor, output_descriptor, convolution_descriptor, cudnn, serialized_graph)); - return CreateOpRunners( + return CreateOpRunners( stream, cudnn, parent_, cudnn_.get(), std::move(op_graph_and_uids.first), kind, input_type, op_graph_and_uids.second, use_fallback, out_exec_plans, /*need_side_input=*/false, numeric_options); @@ -7066,11 +7033,11 @@ CudnnSupport::GraphConvolveRunnerFromDesc( RebuildExecutionPlan(cudnn, algorithm_desc, *op_graph_and_uids.first)); TF_ASSIGN_OR_RETURN(auto runner, - CudnnExecutionPlanRunner::Create( + CudnnExecutionPlanRunner::Create( parent_, cudnn_.get(), std::move(execution_plan), op_graph_and_uids.second, /*need_side_input=*/false)); - return {std::make_unique>( + return {std::make_unique>( std::move(runner))}; #else return tsl::errors::Unimplemented( diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h index 2639bdfe918600..472f709a523e89 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h @@ -239,17 +239,14 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor) override; tsl::Status GetGraphConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, 
- const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + bool use_fallback, const NumericOptions& numeric_options, + std::vector>* out_exec_plans, string serialized_graph) override; tsl::StatusOr> diff --git a/tensorflow/compiler/xla/stream_executor/dnn.cc b/tensorflow/compiler/xla/stream_executor/dnn.cc index 474fc51c017b8a..dec35f55d9db8f 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/dnn.cc @@ -150,18 +150,14 @@ DnnSupport::ConvolveRunnerFromDesc( } tsl::Status DnnSupport::GetGraphConvolveRunners( - bool /* use_cudnn_frontend */, dnn::ConvolutionKind /*kind*/, - dnn::DataType /*input_type*/, dnn::DataType /*output_type*/, - Stream* /*stream*/, const dnn::BatchDescriptor& /*input_descriptor*/, - DeviceMemoryBase /*input_data*/, + dnn::ConvolutionKind /*kind*/, dnn::DataType /*input_type*/, + dnn::DataType /*output_type*/, Stream* /*stream*/, + const dnn::BatchDescriptor& /*input_descriptor*/, const dnn::FilterDescriptor& /*filter_descriptor*/, - DeviceMemoryBase /*filter_data*/, const dnn::BatchDescriptor& /*output_descriptor*/, - DeviceMemoryBase /*output_data*/, const dnn::ConvolutionDescriptor& /*convolution_descriptor*/, - bool /*use_fallback*/, ScratchAllocator* /*scratch_allocator*/, - const NumericOptions& /*numeric_options*/, - std::vector>* /*exec_plans*/, + bool /*use_fallback*/, const NumericOptions& /*numeric_options*/, + std::vector>* /*exec_plans*/, string cudnn_grapph) { return tsl::errors::Unimplemented("GetGraphConvolveRunners not implemented."); } diff --git a/tensorflow/compiler/xla/stream_executor/dnn.h b/tensorflow/compiler/xla/stream_executor/dnn.h index c34dd6d8956587..14dbf00f54369d 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.h +++ b/tensorflow/compiler/xla/stream_executor/dnn.h @@ -965,13 +965,6 @@ class OpRunner { virtual tsl::Status operator()(Stream*, ProfileResult*, DeviceMemoryBase scratch_memory, Args... args) const = 0; - - // Launch the operation with a variable number of operands. 
- virtual tsl::Status operator()(Stream*, ProfileResult*, - DeviceMemoryBase scratch_memory, - std::vector) const { - return tsl::errors::Unimplemented("operator() not implemented."); - }; }; using ConvSignature = void(DeviceMemoryBase /* input_data */, @@ -979,7 +972,8 @@ using ConvSignature = void(DeviceMemoryBase /* input_data */, DeviceMemoryBase /* output_data */); using ConvRunner = OpRunner; -using GraphConvRunner = OpRunner; +using GraphConvSignature = void(std::vector); +using GraphConvRunner = OpRunner; using FusedConvSignature = void(DeviceMemoryBase /* input_data */, DeviceMemoryBase /* filter_data */, @@ -1636,20 +1630,17 @@ class DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor); virtual tsl::Status GetGraphConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + bool use_fallback, const NumericOptions& numeric_options, + std::vector>* out_exec_plans, string serialized_graph); - virtual tsl::StatusOr> + virtual tsl::StatusOr> GraphConvolveRunnerFromDesc( Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, dnn::ConvolutionKind kind, dnn::DataType element_type, diff --git a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h index 11ef2b3497793f..7b0b01a1cc575b 100644 --- a/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h +++ b/tensorflow/compiler/xla/stream_executor/lazy_op_runner.h @@ -161,7 +161,7 @@ struct ConvOp { // Implementation of the concept required by LazyOpRunner, for // GraphConvolveRunner. 
struct GraphConvOp { - using Signature = ConvSignature; + using Signature = GraphConvSignature; struct Config { ConvolutionKind kind; @@ -173,7 +173,7 @@ struct GraphConvOp { string serialized_graph; }; - static tsl::StatusOr>> + static tsl::StatusOr>> RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, Stream* stream) { return stream->GraphConvolveRunnerFromDesc( diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc index 4bc3ae261bd1f3..63b697b66367e0 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc @@ -287,25 +287,23 @@ tsl::Status StreamExecutor::GetConvolveRunners( } tsl::Status StreamExecutor::GetGraphConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, + const dnn::BatchDescriptor& output_descriptor, const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback, - ScratchAllocator* scratch_allocator, const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + const NumericOptions& numeric_options, + std::vector>* out_exec_plans, string serialized_graph) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { return tsl::errors::Unimplemented("DNN library is not found."); } return dnn_support->GetGraphConvolveRunners( - use_cudnn_frontend, kind, input_type, output_type, stream, - input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, use_fallback, - scratch_allocator, numeric_options, out_exec_plans, serialized_graph); + kind, input_type, output_type, stream, input_descriptor, + filter_descriptor, output_descriptor, convolution_descriptor, + use_fallback, numeric_options, out_exec_plans, serialized_graph); } tsl::Status StreamExecutor::GetFusedConvolveRunners( diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h index 7a92d32aa3f99b..e31a58e618c29e 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h @@ -378,17 +378,14 @@ class StreamExecutor { std::vector>* out_exec_plans); tsl::Status GetGraphConvolveRunners( - bool use_cudnn_frontend, dnn::ConvolutionKind kind, - dnn::DataType input_type, dnn::DataType output_type, Stream* stream, - const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data, + dnn::ConvolutionKind kind, dnn::DataType input_type, + dnn::DataType output_type, Stream* stream, + const dnn::BatchDescriptor& input_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, - DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - bool use_fallback, ScratchAllocator* scratch_allocator, - const NumericOptions& numeric_options, - std::vector>* out_exec_plans, + bool use_fallback, const NumericOptions& 
numeric_options, + std::vector>* out_exec_plans, string serialized_graph); tsl::Status GetFusedConvolveRunners( From ecd080bd6c64682f6bee62f4455ea2c37c279f26 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Fri, 21 Jul 2023 21:12:51 +0000 Subject: [PATCH 050/410] Support for FP8 convolutions in XLA. --- tensorflow/compiler/xla/service/gpu/BUILD | 75 ++++--- .../xla/service/gpu/buffer_comparator_test.cc | 16 +- .../service/gpu/cudnn_fused_conv_rewriter.cc | 199 ++++++++--------- .../gpu/cudnn_fused_conv_rewriter_test.cc | 202 +++++++----------- .../service/gpu/gpu_layout_assignment_test.cc | 25 +++ .../xla/stream_executor/cuda/cuda_dnn.cc | 57 +++-- 6 files changed, 282 insertions(+), 292 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 2c2ec570cea653..aa9d0055cf55ee 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3066,6 +3066,45 @@ cc_library( "//tensorflow/tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]), +) + +xla_cc_test( + name = "cudnn_fused_conv_rewriter_test", + srcs = ["cudnn_fused_conv_rewriter_test.cc"], + shard_count = 10, + tags = [ + "gpu", + "no_oss", + "noasan", + "nomsan", + # This test runs some fusions that are only supported on Ampere+. + "requires-gpu-sm80", + ], + deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":cudnn_fused_conv_rewriter", + ":gpu_conv_rewriter", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:convert_mover", + "//tensorflow/compiler/xla/service:hlo_constant_folding", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:pattern_matcher_gmock", + "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", ]), ) @@ -3086,6 +3125,7 @@ cc_library( "//tensorflow/tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", ]), ) @@ -3109,41 +3149,6 @@ xla_cc_test( ], ) -xla_cc_test( - name = "cudnn_fused_conv_rewriter_test", - srcs = ["cudnn_fused_conv_rewriter_test.cc"], - shard_count = 10, - tags = [ - "gpu", - "no_oss", - "noasan", - "nomsan", - # This test runs some fusions that are only supported on Ampere+. 
- "requires-gpu-sm80", - ], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":cudnn_fused_conv_rewriter", - ":gpu_conv_rewriter", - "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:convert_mover", - "//tensorflow/compiler/xla/service:hlo_constant_folding", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:hlo_pass_pipeline", - "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/service:pattern_matcher_gmock", - "//tensorflow/compiler/xla/service:reshape_mover", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", - "//tensorflow/compiler/xla/tests:filecheck", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - xla_cc_test( name = "conv_layout_normalization_test", srcs = ["conv_layout_normalization_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc index 04550d6bb122aa..74c91b34082240 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -166,22 +166,18 @@ TEST_F(BufferComparatorTest, TestInfs) { EXPECT_TRUE( CompareEqualFloatBuffers({inf}, {std::nanf("")})); EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); - EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); - EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); - EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {-65504})); - EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {-inf})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {448})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-448})); EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); EXPECT_FALSE( CompareEqualFloatBuffers({inf}, {std::nanf("")})); EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); - EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); - EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); - EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); - EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-inf})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {57344})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-57344})); EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index 3c1f337c0f1d20..e5a0339c13faec 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/statusor.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cudnn/cudnn.h" namespace xla { namespace gpu { @@ -353,83 +354,83 @@ void CaptureConvGraphRecursive(HloInstruction* instr, final_instr = instr; } - if (instr->user_count() == 1) { - HloInstruction *op, *operand, *user = instr->users()[0]; - if (pattern_level == 0) { - // Add - if (Match(user, m::AddAnyOrder(&op, m::Op(), m::Op(&operand)))) { - graph_string.AppendOp("add", op->shape().element_type()); - operands.push_back(operand); - CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, - final_instr, 0); - return; - } - // Scale - if (Match(user, m::MultiplyAnyOrder(&op, m::Op(), - m::Broadcast(m::Op(&operand)))) && - ShapeUtil::IsScalar(operand->shape())) { - graph_string.AppendOp("scale", op->shape().element_type()); - operands.push_back(operand); - CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, - final_instr, 0); - return; - } - // Inverse Scale - if (Match(user, m::Divide(&op, m::Op(), m::Broadcast(m::Op(&operand)))) && - ShapeUtil::IsScalar(operand->shape())) { - graph_string.AppendOp("invscale", op->shape().element_type()); - operands.push_back(operand); - CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, - final_instr, 0); - return; - } - // ReLU - if (Match(user, m::MaximumAnyOrder(&op, m::Op(), - m::Broadcast(m::ConstantScalar(0))))) { - graph_string.AppendOp("relu", op->shape().element_type()); - CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, - final_instr, 0); - return; - } - } + if (instr->user_count() != 1) { + return; + } - if (pattern_level == 1) { - // Convert with clamp to FP8 types - HloInstruction *clamp_lower, *clamp_upper; - if (Match(user, - m::Convert( - &op, - m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), - m::Op(), - m::Broadcast(m::ConstantScalar(&clamp_upper)))))) { - if ((op->shape().element_type() == F8E4M3FN && - clamp_lower->literal().IsAllFloat(static_cast( - std::numeric_limits::lowest())) && - clamp_upper->literal().IsAllFloat(static_cast( - std::numeric_limits::max()))) || - (op->shape().element_type() == F8E5M2 && - clamp_lower->literal().IsAllFloat(static_cast( - std::numeric_limits::lowest())) && - clamp_upper->literal().IsAllFloat(static_cast( - std::numeric_limits::max())))) { - graph_string.ChangeDataType(op->shape().element_type()); - CaptureConvGraphRecursive(user, operands, graph_string, - visited_instrs, final_instr, 0); - return; - } - } + HloInstruction *op, *operand, *user = instr->users()[0]; + if (pattern_level == 0) { + // Add + if (Match(user, m::AddAnyOrder(&op, m::Op(), m::Op(&operand)))) { + graph_string.AppendOp("add", op->shape().element_type()); + operands.push_back(operand); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; } - - // If none of the matches was successful and the pattern level is below the - // maximum level, attempt to match at higher level. 
- if (pattern_level < max_pattern_level) { + // Scale + if (Match(user, m::MultiplyAnyOrder(&op, m::Op(), + m::Broadcast(m::Op(&operand)))) && + ShapeUtil::IsScalar(operand->shape())) { + graph_string.AppendOp("scale", op->shape().element_type()); + operands.push_back(operand); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; + } + // Inverse Scale + if (Match(user, m::Divide(&op, m::Op(), m::Broadcast(m::Op(&operand)))) && + ShapeUtil::IsScalar(operand->shape())) { + graph_string.AppendOp("invscale", op->shape().element_type()); + operands.push_back(operand); CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, - final_instr, pattern_level + 1); + final_instr, 0); return; } + // ReLU + if (Match(user, m::MaximumAnyOrder(&op, m::Op(), + m::Broadcast(m::ConstantScalar(0))))) { + graph_string.AppendOp("relu", op->shape().element_type()); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; + } + } + + if (pattern_level == 1) { + // Convert with clamp to FP8 types + HloInstruction *clamp_lower, *clamp_upper; + if (Match( + user, + m::Convert( + &op, + m::Clamp(m::Broadcast(m::ConstantScalar(&clamp_lower)), m::Op(), + m::Broadcast(m::ConstantScalar(&clamp_upper)))))) { + if ((op->shape().element_type() == F8E4M3FN && + clamp_lower->literal().IsAllFloat(static_cast( + std::numeric_limits::lowest())) && + clamp_upper->literal().IsAllFloat(static_cast( + std::numeric_limits::max()))) || + (op->shape().element_type() == F8E5M2 && + clamp_lower->literal().IsAllFloat(static_cast( + std::numeric_limits::lowest())) && + clamp_upper->literal().IsAllFloat(static_cast( + std::numeric_limits::max())))) { + graph_string.ChangeDataType(op->shape().element_type()); + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, 0); + return; + } + } } - return; + // If none of the matches was successful and the pattern level is below the + // maximum level, attempt to match at higher level. + if (pattern_level < max_pattern_level) { + CaptureConvGraphRecursive(user, operands, graph_string, visited_instrs, + final_instr, pattern_level + 1); + return; + } } // Captures in a GraphString the subgraph of pointwise operations operating on @@ -477,16 +478,17 @@ CaptureConvGraph(HloInstruction* instr, HloInstruction* x_scale, // following steps are elided and rewritten into a ForwardGraph Custom Call: // // 1. Cast the filter and input from FP8 to a wider type such as FP16 or FP32. -// 2. Unscale the filter and input by multiplying or dividing by scalars. +// 2. Optionally unscale the filter and input by multiplying or dividing by +// scalars. // 3. Evaluate the convolution based on the scaled filter and input. // 4. Apply a series of elementwise transformations, where a transformation can // be adding a matrix bias, applying a ReLU activation, or // multiplying or dividing by a broadcast scalar. -// 5. Cast the output back to FP8. +// 5. Optionally cast the output back to FP8. 
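//
// As a rough illustration only (shapes abbreviated; the example mirrors the
// TestConvScaledOutputF8 test updated later in this patch, and the numbers on
// the right refer to the steps above, with the optional step 2 omitted), a
// captured pattern may look like
//
//   input_f32  = f32[...] convert(f8e4m3fn[...] input)                   (1)
//   filter_f32 = f32[...] convert(f8e4m3fn[...] filter)                  (1)
//   conv_a     = f32[...] convolution(input_f32, filter_f32)             (3)
//   scaled     = f32[...] multiply(conv_a, broadcast(z_scale))           (4)
//   clamped    = f32[...] clamp(broadcast(-448), scaled, broadcast(448))
//   conv_f8    = f8e4m3fn[...] convert(clamped)                          (5)
//
// which the rewrite collapses into a single __cudnn$convForwardGraph custom
// call whose serialized graph is "conv[f32]->scale[f8e4m3fn]->".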
StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { bool changed = false; -#if CUDA_VERSION >= 12000 +#if (CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900) for (auto instr : comp->MakeInstructionPostOrder()) { const DebugOptions& debug_options = instr->GetModule()->config().debug_options(); @@ -504,6 +506,7 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { m::CustomCall( &convolution, m::AnyOf( + m::Op(&input).WithPredicate(IsF8Type), m::Convert(m::Op(&input).WithPredicate(IsF8Type)), m::Divide(&x_scale_op, m::Convert(m::Op(&input).WithPredicate(IsF8Type)), @@ -513,6 +516,7 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { m::Convert(m::Op(&input).WithPredicate(IsF8Type)), m::Broadcast(m::Op(&x_scale)))), m::AnyOf( + m::Op(&filter).WithPredicate(IsF8Type), m::Convert(m::Op(&filter).WithPredicate(IsF8Type)), m::Divide(&w_scale_op, m::Convert(m::Op(&input).WithPredicate(IsF8Type)), @@ -523,6 +527,12 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { m::Broadcast(m::Op(&w_scale))))), 0); if (Match(instr, pattern)) { + if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] { + return absl::StrCat("F8GraphConv: ", convolution->ToString()); + })) { + continue; + } + std::vector operands; GraphString graph_string; HloInstruction* final_instr; @@ -530,31 +540,28 @@ StatusOr F8GraphConv(HloComputation* comp, se::CudaComputeCapability cc) { const_cast(instr), x_scale, w_scale, x_scale_op ? x_scale_op->opcode() == HloOpcode::kMultiply : false, w_scale_op ? w_scale_op->opcode() == HloOpcode::kMultiply : false); - if (graph_string.Size() > 1) { - TF_ASSIGN_OR_RETURN( - auto config, convolution->backend_config()); - config.set_serialized_graph(graph_string.Graph()); - operands.insert(operands.begin(), input); - operands.insert(operands.begin() + 1, filter); - - Shape new_shape = ShapeUtil::MakeTupleShape( - {ShapeUtil::ChangeElementType( - ShapeUtil::GetTupleElementShape(convolution->shape(), 0), - final_instr->shape().element_type()), - ShapeUtil::GetTupleElementShape(convolution->shape(), 1)}); - HloInstruction* new_convolution = comp->AddInstruction( - convolution->CloneWithNewOperands(new_shape, operands)); - new_convolution->set_custom_call_target( - kCudnnConvForwardGraphCallTarget); - TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); - TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, - MakeGetTupleElementHlo(new_convolution, 0)); - TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); - changed = true; - } + TF_ASSIGN_OR_RETURN( + auto config, convolution->backend_config()); + config.set_serialized_graph(graph_string.Graph()); + operands.insert(operands.begin(), input); + operands.insert(operands.begin() + 1, filter); + + Shape new_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::ChangeElementType( + ShapeUtil::GetTupleElementShape(convolution->shape(), 0), + final_instr->shape().element_type()), + ShapeUtil::GetTupleElementShape(convolution->shape(), 1)}); + HloInstruction* new_convolution = comp->AddInstruction( + convolution->CloneWithNewOperands(new_shape, operands)); + new_convolution->set_custom_call_target(kCudnnConvForwardGraphCallTarget); + TF_RETURN_IF_ERROR(new_convolution->set_backend_config(config)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_gte, + MakeGetTupleElementHlo(new_convolution, 0)); + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(final_instr, new_gte)); + changed = true; } } -#endif +#endif // CUDA_VERSION >= 12000 && CUDNN_VERSION >= 8900 
return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 401bbe13572447..c505b1124618ab 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/tsl/lib/core/status_test_util.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cudnn/cudnn.h" namespace xla { namespace gpu { @@ -140,14 +142,13 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { } } - void TestF8(absl::string_view pre_hlo_string, - absl::string_view custom_call_string, - absl::string_view serialized_graph_string) { + void TestF8(std::string pre_hlo_string, std::string custom_call_string, + std::string serialized_graph_string) { if (GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::HOPPER)) { - // On Hopper or newer architectures, test numerical correctness and verify - // the HLO of the Custom Call with operand and return layouts and the - // serialized graph based on the full compiler pipeline. + // On Hopper and newer architectures, test numerical correctness and + // verify the HLO of the Custom Call with operand and return layouts and + // the serialized graph based on the full compiler pipeline. std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string); EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert"))); EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv")); @@ -165,12 +166,24 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { EXPECT_TRUE(*filecheck_result); } else { // On older architectures, disregard layout information and only verify - // the serialized graph based on the GpuConvRewriter and - // CudnnFusedConvRewriter passes. + // the basic configuration of the convolution Custom Call using the number + // of operands and the window_size and serialized graph attributes based + // on the GpuConvRewriter and CudnnFusedConvRewriter passes. 
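      // A sketch of the intent (pattern text abbreviated, not verbatim): a
      // custom_call_string such as
      //   CHECK: [[conv:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0})
      //          custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]),
      //          window={size=3x3 ...}, dim_labels=..., custom_call_target=...
      // is trimmed by the two erase() calls below to roughly
      //   CHECK: custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]),
      //          window={size=3x3 ...}
      // so that only the operand count and the window attribute have to match
      // on pre-Hopper GPUs.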
+ std::string::size_type p0 = custom_call_string.find(':'); + std::string::size_type p1 = custom_call_string.find("custom-call"); + custom_call_string.erase(p0 + 1, p1 - p0 - 2); + p0 = custom_call_string.find(", dim_labels"); + custom_call_string.erase(p0); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(pre_hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(GpuConvRewriter(), module.get())); + RunAndFilecheckHloRewrite( + module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), + CudnnFusedConvRewriter( + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, 0}), + custom_call_string); RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), CudnnFusedConvRewriter( @@ -178,6 +191,30 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { serialized_graph_string); } } + + void TestF8Parameterized(std::string template_pre_hlo_string, + std::string template_custom_call_string, + std::string template_serialized_graph_string) { + std::array types = {"f8e4m3fn", "f8e5m2"}; + std::array clamp_lower = {"-448.", "-57344."}; + std::array clamp_upper = {"448.", "57344."}; + absl::flat_hash_map replacements; + for (int i = 0; i < 2; ++i) { + replacements["<>"] = types[i]; + for (int j = 0; j < 2; ++j) { + replacements["<>"] = types[j]; + for (int k = 0; k < 2; ++k) { + replacements["<>"] = types[k]; + replacements["<>"] = clamp_lower[k]; + replacements["<>"] = clamp_upper[k]; + TestF8(absl::StrReplaceAll(template_pre_hlo_string, replacements), + absl::StrReplaceAll(template_custom_call_string, replacements), + absl::StrReplaceAll(template_serialized_graph_string, + replacements)); + } + } + } + } }; TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { @@ -640,7 +677,10 @@ TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01})); } -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { +TEST_F(CudnnFusedConvRewriterTest, TestConvF8) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif TestF8( // pre_hlo R"( @@ -649,31 +689,23 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { ENTRY Test { input = f8e4m3fn[1,128,6,6] parameter(0) filter = f8e4m3fn[3,3,128,16] parameter(1) - input_f32 = f32[1,128,6,6] convert(input) - filter_f32 = f32[3,3,128,16] convert(filter) - z_scale = f32[] parameter(2) - z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} - conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) - c1 = f32[] constant(-448.) - c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} - c2 = f32[] constant(448.) 
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} - conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) - ROOT conv_f8 = f8e4m3fn[1,16,6,6] convert(conv_a_clamped) + ROOT conv_a = f8e4m3fn[1,16,6,6] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 })", // custom_call R"( -// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" )", // serialized_graph R"( -// CHECK: "serialized_graph":"conv[f32]-\u003escale[f8e4m3fn]-\u003e" +// CHECK: "serialized_graph":"conv[f8e4m3fn]-\u003e" )"); } -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8) { +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledOutputF8) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif TestF8( // pre_hlo R"( @@ -682,17 +714,11 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8) { ENTRY Test { input = f8e4m3fn[1,128,6,6] parameter(0) filter = f8e4m3fn[3,3,128,16] parameter(1) - input_scale = f32[] parameter(2) - input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} - filter_scale = f32[] parameter(3) - filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} input_f32 = f32[1,128,6,6] convert(input) - input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) filter_f32 = f32[3,3,128,16] convert(filter) - filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) - z_scale = f32[] parameter(4) + z_scale = f32[] parameter(2) z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} - conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + conv_a = f32[1,16,6,6] convolution(input_f32, filter_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) c1 = f32[] constant(-448.) 
c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} @@ -704,101 +730,26 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8) { })", // custom_call R"( -// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" - )", - // serialized_graph - R"( -// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e4m3fn]-\u003e" - )"); -} - -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledInputE5M2F8) { - TestF8( - // pre_hlo - R"( - HloModule Test - - ENTRY Test { - input = f8e5m2[1,128,6,6] parameter(0) - filter = f8e4m3fn[3,3,128,16] parameter(1) - input_scale = f32[] parameter(2) - input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} - filter_scale = f32[] parameter(3) - filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} - input_f32 = f32[1,128,6,6] convert(input) - input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) - filter_f32 = f32[3,3,128,16] convert(filter) - filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) - z_scale = f32[] parameter(4) - z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} - conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) - c1 = f32[] constant(-57344.) - c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} - c2 = f32[] constant(57344.) - c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} - conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) - ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) - - })", - // custom_call - R"( -// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" - )", - // serialized_graph - R"( -// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" - )"); -} - -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledFilterE5M2F8) { - TestF8( - // pre_hlo - R"( - HloModule Test - - ENTRY Test { - input = f8e4m3fn[1,128,6,6] parameter(0) - filter = f8e5m2[3,3,128,16] parameter(1) - input_scale = f32[] parameter(2) - input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} - filter_scale = f32[] parameter(3) - filter_scale_bcast = f32[3,3,128,16] broadcast(filter_scale), dimensions={} - input_f32 = f32[1,128,6,6] convert(input) - input_unscaled = f32[1,128,6,6] multiply(input_f32, input_scale_bcast) - filter_f32 = f32[3,3,128,16] convert(filter) - filter_unscaled = f32[3,3,128,16] multiply(filter_f32, filter_scale_bcast) - z_scale = f32[] parameter(4) - z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} - conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) - c1 = f32[] constant(-57344.) - c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} - c2 = f32[] constant(57344.) 
- c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} - conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) - ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) - - })", - // custom_call - R"( -// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e4m3fn[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" )", // serialized_graph R"( -// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f8e4m3fn]-\u003e" )"); } -TEST_F(CudnnFusedConvRewriterTest, TestConvScaledInputE5M2FilterE5M2F8) { - TestF8( +TEST_F(CudnnFusedConvRewriterTest, TestConvScaledF8Parameterized) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif + TestF8Parameterized( // pre_hlo R"( HloModule Test ENTRY Test { - input = f8e5m2[1,128,6,6] parameter(0) - filter = f8e5m2[3,3,128,16] parameter(1) + input = <>[1,128,6,6] parameter(0) + filter = <>[3,3,128,16] parameter(1) input_scale = f32[] parameter(2) input_scale_bcast = f32[1,128,6,6] broadcast(input_scale), dimensions={} filter_scale = f32[] parameter(3) @@ -811,25 +762,28 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledInputE5M2FilterE5M2F8) { z_scale_bcast = f32[1,16,6,6] broadcast(z_scale), dimensions={} conv_a = f32[1,16,6,6] convolution(input_unscaled, filter_unscaled), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 conv_a_scaled = f32[1,16,6,6] multiply(conv_a, z_scale_bcast) - c1 = f32[] constant(-57344.) + c1 = f32[] constant(<>) c1_bcast = f32[1,16,6,6] broadcast(c1), dimensions={} - c2 = f32[] constant(57344.) 
+ c2 = f32[] constant(<>) c2_bcast = f32[1,16,6,6] broadcast(c2), dimensions={} conv_a_clamped = f32[1,16,6,6] clamp(c1_bcast, conv_a_scaled, c2_bcast) - ROOT conv_f8 = f8e5m2[1,16,6,6] convert(conv_a_clamped) + ROOT conv_f8 = <>[1,16,6,6] convert(conv_a_clamped) })", // custom_call R"( -// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (f8e5m2[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" +// CHECK: [[cudnn_conv_4_0:%[^ ]+]] = (<>[1,6,6,16]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[OPERAND0:%[^ ]+]], [[OPERAND1:%[^ ]+]], [[OPERAND2:%[^ ]+]], [[OPERAND3:%[^ ]+]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" )", // serialized_graph R"( -// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[f8e5m2]-\u003e" +// CHECK: "serialized_graph":"conv[f32]-\u003escale[f32]-\u003escale[<>]-\u003e" )"); } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif TestF8( // pre_hlo R"( @@ -871,6 +825,9 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvScaledBiasF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledF8) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif TestF8( // pre_hlo R"( @@ -904,6 +861,9 @@ TEST_F(CudnnFusedConvRewriterTest, TestConvInvscaledF8) { } TEST_F(CudnnFusedConvRewriterTest, TestConvScaledReluActivationF8) { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8900) + GTEST_SKIP() << "FP8 convolutions require CUDA 12 and cuDNN 8.9."; +#endif TestF8( // pre_hlo R"( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index b057885a5f2256..292d10050b5cbd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -396,6 +396,31 @@ ENTRY entry { expect_layout(call_0->operand(1)->shape(), {1, 2, 0}); } +TEST_F(LayoutAssignmentTest, ConvCuDNNF8) { + if (!GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() << "FP8 convolutions require HOPPER or newer archiecture."; + } + + const char* hlo = R"( + + HloModule jit_conv_general_dilated + + ENTRY main.4 { + Arg_0 = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0) + Arg_1 = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1) + ROOT conv = f8e4m3fn[1,64,64,32]{3,2,1,0} convolution(Arg_0, Arg_1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f + } +)"; + + MatchOptimizedHlo(hlo, R"( + // CHECK: [[P0:%[^ ]+]] = f8e4m3fn[1,64,64,16]{3,2,1,0} parameter(0) + // CHECK: [[P1:%[^ ]+]] = f8e4m3fn[3,3,16,32]{3,2,1,0} parameter(1) + // CHECK-NEXT: [[P2:%[^ ]+]] = f8e4m3fn[32,3,3,16]{3,2,1,0} transpose([[P1]]), dimensions={3,0,1,2} + // CHECK-NEXT: [[CONV:%[^ ]+]] = (f8e4m3fn[1,64,64,32]{3,2,1,0}, u8[0]{0}) custom-call([[P0]], [[P2]]), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_o01i->b01f, custom_call_target="__cudnn$convForwardGraph" + )"); +} + TEST_F(LayoutAssignmentTest, ConvCuDNNBF16) { if (!GetCudaComputeCapability().IsAtLeast( se::CudaComputeCapability::AMPERE)) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc 
b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 717eecb72f199c..430361066d394d 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -3541,18 +3541,6 @@ tsl::StatusOr CreateCudnnTensor( return tensor; } -tsl::StatusOr CreateCudnnTensor( - const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, - bool is_virtual = false) { - auto tensor = cudnn_frontend::TensorBuilder() - .cloneFrom(original, uid) - .setAlignment(32) - .setDataType(ToCudnnDataType(dtype)) - .setVirtual(is_virtual) - .build(); - RETURN_MSG_IF_CUDNN_ERROR(tensor); - return tensor; -} #else tsl::StatusOr CreateCudnnTensor( absl::Span dims, absl::Span strides, @@ -3583,13 +3571,22 @@ tsl::StatusOr CreateCudnnTensor( return tensor; } +#endif + +#if (CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND) tsl::StatusOr CreateCudnnTensor( - cudnn_frontend::Tensor original, int64_t uid, dnn::DataType dtype, + const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, bool is_virtual = false) { - return tsl::errors : Internal("Not implemented."); + auto tensor = cudnn_frontend::TensorBuilder() + .cloneFrom(original, uid) + .setAlignment(32) + .setDataType(ToCudnnDataType(dtype)) + .setVirtual(is_virtual) + .build(); + RETURN_MSG_IF_CUDNN_ERROR(tensor); + return tensor; } - -#endif +#endif // CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND #if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) enum CudnnfMHAUid { @@ -4214,9 +4211,9 @@ OpNameStringToInputKindAndMode(string opstring) { // TODO(philipphack): Consider merging with GetCudnnOperationGraph and // GetCudnnFusedOperationGraph. -// Returns a generic cuDNN OperationGraph for ForwardGraph -// convolutions with dynamically identified fused ops and the -// associated set of UIDs of non-virtual cuDNN tensors. +// Returns a generic cuDNN OperationGraph for ForwardGraph convolutions with the +// fused ops listed in serialized_graph and the associated set of UIDs of +// non-virtual cuDNN tensors. tsl::StatusOr, std::vector>> GetGenericCudnnOperationGraph( @@ -4290,15 +4287,15 @@ GetGenericCudnnOperationGraph( auto next_uid = [&non_virtual_uids, &virtual_uids](bool is_virtual) -> int64_t { - int64_t next_uid = - std::max(non_virtual_uids.empty() - ? 0 - : *std::max_element(non_virtual_uids.begin(), - non_virtual_uids.end()), - virtual_uids.empty() ? 0 - : *std::max_element(virtual_uids.begin(), - virtual_uids.end())) + - 1; + int64_t max_non_virtual_uid = + non_virtual_uids.empty() ? 0 + : *std::max_element(non_virtual_uids.begin(), + non_virtual_uids.end()); + int64_t max_virtual_uid = + virtual_uids.empty() + ? 
0 + : *std::max_element(virtual_uids.begin(), virtual_uids.end()); + int64_t next_uid = std::max(max_non_virtual_uid, max_virtual_uid) + 1; if (is_virtual) { return virtual_uids.emplace_back(next_uid); } else { @@ -4351,9 +4348,9 @@ GetGenericCudnnOperationGraph( TF_ASSIGN_OR_RETURN( auto tensor_y, CreateCudnnTensor(output_dims, output_strides, - next_uid(/*is_virtual=*/true), + next_uid(/*is_virtual=*/op_sequence.size() > 1), op_sequence[0].output_type, vector_size, vector_dim, - /*is_virtual=*/true)); + /*is_virtual=*/op_sequence.size() > 1)); auto accumulator_type = ToCudnnDataType(GetConvAccumulatorType(input_type)); CHECK_NE(convolution_descriptor.pad_alignment(), From da22a881a3d24fd4f357207034ba6c596aa414d0 Mon Sep 17 00:00:00 2001 From: Philipp Hack Date: Sat, 22 Jul 2023 01:00:29 +0000 Subject: [PATCH 051/410] Support for FP8 convolutions in XLA. --- tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index 430361066d394d..1608b15ba31144 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -3573,10 +3573,10 @@ tsl::StatusOr CreateCudnnTensor( #endif -#if (CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND) tsl::StatusOr CreateCudnnTensor( const cudnn_frontend::Tensor& original, int64_t uid, dnn::DataType dtype, bool is_virtual = false) { +#if (CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND) auto tensor = cudnn_frontend::TensorBuilder() .cloneFrom(original, uid) .setAlignment(32) @@ -3585,8 +3585,10 @@ tsl::StatusOr CreateCudnnTensor( .build(); RETURN_MSG_IF_CUDNN_ERROR(tensor); return tensor; -} +#else + return tsl::errors::Internal("Not implemented."); #endif // CUDNN_VERSION >= 8900 && TF_ENABLE_CUDNN_FRONTEND +} #if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND) enum CudnnfMHAUid { From a225f76b3ac5bb1cd9d0f490b786c1bd0f7a3dac Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Sat, 22 Jul 2023 16:31:51 +0800 Subject: [PATCH 052/410] Update gpu_kernel_helper.h --- tensorflow/core/util/gpu_kernel_helper.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h index 6857239488f307..7f22d65360696a 100644 --- a/tensorflow/core/util/gpu_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -186,6 +186,12 @@ __host__ __device__ inline int tf_min(unsigned int x, int y) { __host__ __device__ inline int tf_max(unsigned int x, int y) { return max(static_cast(x), y); } +__host__ __device__ inline int tf_min(int x, unsigned int y) { + return min(x, static_cast(y)); +} +__host__ __device__ inline int tf_max(int x, unsigned int y) { + return max(x, static_cast(y)); +} #endif #endif From 6994ff4d182385de2dca842e786aaa00cbeb08ac Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Sun, 23 Jul 2023 15:29:27 +0800 Subject: [PATCH 053/410] Update reduction_gpu_kernels.cu.h --- tensorflow/core/kernels/reduction_gpu_kernels.cu.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index fd87098e5483fc..2eb19fed073961 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ 
-191,7 +191,7 @@ __global__ __launch_bounds__(1024) void BlockReduceKernel( // elements: ----------------- // grid: |====|====|====|====|====| const int num_elements_to_reduce = - max(min(num_elems - bid * blockDim.x, num_threads), 0); + max(min(static_cast(num_elems - bid * blockDim.x), num_threads), 0); sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce); From 849369f762483dc97c910b605a82b198dcbc3e26 Mon Sep 17 00:00:00 2001 From: shuw Date: Sun, 23 Jul 2023 20:22:57 -0700 Subject: [PATCH 054/410] Test format change --- .../service/gpu/tests/gemm_rewrite_test.cc | 150 +++++++++--------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 5fedc8a5fc2c3d..c5db73de2d3f4d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -5380,21 +5380,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"BIAS" +; CHECK: } ; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]]) )"); } @@ -5465,21 +5465,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1 ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: 
"rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"BIAS" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]} ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]]) )"); @@ -5538,21 +5538,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":1 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]]) )"); } @@ -5619,21 +5619,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":1 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]} ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]]) )"); @@ -5686,21 +5686,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = 
f32[48,32]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]} ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2) ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]]) From 812ad31955a986ca5a247fd19f306fa5df560b60 Mon Sep 17 00:00:00 2001 From: shuw Date: Sun, 23 Jul 2023 20:32:34 -0700 Subject: [PATCH 055/410] Reformat test --- .../service/gpu/tests/gemm_rewrite_test.cc | 150 +++++++++--------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 5fedc8a5fc2c3d..c5db73de2d3f4d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -5380,21 +5380,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"BIAS" +; CHECK: } ; CHECK: ROOT [[OUT:%[^ ]+]] = f16[4,16,32]{2,1,0} bitcast([[GEMM]]) )"); } @@ -5465,21 +5465,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: 
[[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1 ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f16[64,32]{1,0} custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"BIAS\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"BIAS" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f16[60,31]{1,0} slice([[GEMM]]), slice={[0:60], [0:31]} ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f16[4,15,31]{2,1,0} bitcast([[SLICE]]) )"); @@ -5538,21 +5538,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[64,32]{1,0} custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":1 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK: ROOT [[OUT:%[^ ]+]] = f32[4,16,32]{2,1,0} bitcast([[GEMM]]) )"); } @@ -5619,21 +5619,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":1 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: 
\"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":1 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[45,31]{1,0} slice([[GEMM]]), slice={[0:45], [0:31]} ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[3,15,31]{2,1,0} bitcast([[SLICE]]) )"); @@ -5686,21 +5686,21 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[GEMM:%[^ ]+]] = f32[48,32]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", -; CHECK: backend_config="{ -; CHECK-DAG: \"alpha_real\":1 -; CHECK-DAG: \"alpha_imag\":0 -; CHECK-DAG: \"beta\":0 -; CHECK-DAG: \"dot_dimension_numbers\":{ -; CHECK-DAG: \"lhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"rhs_contracting_dimensions\":[\"1\"] -; CHECK-DAG: \"lhs_batch_dimensions\":[] -; CHECK-DAG: \"rhs_batch_dimensions\":[] -; CHECK-DAG: } -; CHECK-DAG: \"precision_config\":{ -; CHECK-DAG: \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] -; CHECK-DAG: } -; CHECK-DAG: \"epilogue\":\"DEFAULT\" -; CHECK: }" +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["1"] +; CHECK-DAG: "lhs_batch_dimensions":[] +; CHECK-DAG: "rhs_batch_dimensions":[] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } ; CHECK-NEXT: [[SLICE:%[^ ]+]] = f32[32,16]{1,0} slice([[GEMM]]), slice={[16:48], [16:32]} ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32,16]{1,0} parameter(2) ; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = f32[32,16]{1,0} add([[SLICE]], [[B]]) From 276a0915e0157aa63dcecfb628c69ae0a77a20b3 Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Mon, 17 Jul 2023 14:16:18 +0000 Subject: [PATCH 056/410] Update to ACL 23.05.1, add ACL reorders --- tensorflow/workspace2.bzl | 15 +- third_party/compute_library/BUILD | 182 +-- .../acl_fixed_format_kernels_striding.patch | 70 - .../compute_library/acl_openmp_fix.patch | 46 - .../compute_library/compute_library.patch | 51 +- .../onednn_acl_depthwise_convolution.patch | 312 ++-- .../onednn_acl_fixed_format_kernels.patch | 1372 ++++++++++++----- .../mkl_dnn/onednn_acl_remove_winograd.patch | 326 ++++ third_party/mkl_dnn/onednn_acl_reorder.patch | 370 +++++ .../onednn_acl_threadpool_scheduler.patch | 17 + 10 files changed, 1921 insertions(+), 840 deletions(-) delete mode 100644 third_party/compute_library/acl_fixed_format_kernels_striding.patch delete mode 100644 third_party/compute_library/acl_openmp_fix.patch create mode 100644 third_party/mkl_dnn/onednn_acl_remove_winograd.patch create mode 100644 third_party/mkl_dnn/onednn_acl_reorder.patch diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 
9f1f474c270fe9..f6b9376dd7d73b 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -204,11 +204,13 @@ def _tf_repositories(): build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD", patch_file = [ "//third_party/mkl_dnn:onednn_acl_threadcap.patch", + "//third_party/mkl_dnn:onednn_acl_remove_winograd.patch", "//third_party/mkl_dnn:onednn_acl_fixed_format_kernels.patch", "//third_party/mkl_dnn:onednn_acl_depthwise_convolution.patch", "//third_party/mkl_dnn:onednn_acl_threadpool_scheduler.patch", "//third_party/mkl_dnn:onednn_acl_reorder_padded.patch", "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", + "//third_party/mkl_dnn:onednn_acl_reorder.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", @@ -217,15 +219,10 @@ def _tf_repositories(): tf_http_archive( name = "compute_library", - sha256 = "e20a060d3c4f803889d96c2f0b865004ba3ef4e228299a44339ea1c1ba827c85", - strip_prefix = "ComputeLibrary-22.11", - build_file = "//third_party/compute_library:BUILD", - patch_file = [ - "//third_party/compute_library:compute_library.patch", - "//third_party/compute_library:acl_fixed_format_kernels_striding.patch", - "//third_party/compute_library:acl_openmp_fix.patch", - ], - urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v22.11.tar.gz"), + patch_file = ["//third_party/compute_library:compute_library.patch"], + sha256 = "c4ca329a78da380163b2d86e91ba728349b6f0ee97d66e260a694ef37f0b0d93", + strip_prefix = "ComputeLibrary-23.05.1", + urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.1.tar.gz"), ) tf_http_archive( diff --git a/third_party/compute_library/BUILD b/third_party/compute_library/BUILD index 14bde5ac345c80..2e0d7500b4b05b 100644 --- a/third_party/compute_library/BUILD +++ b/third_party/compute_library/BUILD @@ -2,184 +2,12 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") exports_files(["LICENSE"]) -cc_library( - name = "include", - hdrs = glob([ - "include/**/*.h", - "include/**/*.hpp", - ]), - includes = ["include"], - strip_include_prefix = "include", -) - -_COMPUTE_LIBRARY_DEFINES = [ - "ARM_COMPUTE_OPENMP_SCHEDULER", - "ARM_COMPUTE_CPU_ENABLED", - "ENABLE_NEON", - "ARM_COMPUTE_ENABLE_NEON", - "ENABLE_SVE", - "ARM_COMPUTE_ENABLE_SVE", - "ARM_COMPUTE_ENABLE_BF16", - "ARM_COMPUTE_ENABLE_I8MM", - "ARM_COMPUTE_ENABLE_SVEF32MM", - "ENABLE_FP32_KERNELS", - "ENABLE_QASYMM8_KERNELS", - "ENABLE_QASYMM8_SIGNED_KERNELS", - "ENABLE_QSYMM16_KERNELS", - "ENABLE_INTEGER_KERNELS", - "ENABLE_NHWC_KERNELS", - "ENABLE_NCHW_KERNELS", - "ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS", -] - -cc_library( - name = "arm_compute_sve2", - srcs = glob( - [ - "src/cpu/kernels/**/sve2/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - ), - copts = [ - "-march=armv8.6-a+sve2", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES + ["ARM_COMPUTE_ENABLE_SVE2"], - includes = [ - "src/core/NEON/kernels/arm_conv", - "src/core/NEON/kernels/arm_gemm", - "src/core/NEON/kernels/assembly", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - deps = ["include"], -) - -cc_library( - name = "arm_compute_sve", - srcs = glob( - [ - "src/core/NEON/kernels/arm_gemm/kernels/sve_*/*.cpp", - "src/core/NEON/kernels/arm_conv/**/kernels/sve_*/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/sve_*.cpp", - "src/core/NEON/kernels/batchnormalization/impl/SVE/*.cpp", - 
"src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", - "src/cpu/kernels/**/sve/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - ) + [ - "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", - "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", - ], - copts = [ - "-march=armv8.2-a+sve", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES, - includes = [ - "src/core/NEON/kernels/arm_conv", - "src/core/NEON/kernels/arm_gemm", - "src/core/NEON/kernels/assembly", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - deps = ["include"], -) - -cc_library( - name = "arm_compute", - srcs = glob( - [ - "src/common/**/*.cpp", - "src/core/*.cpp", - "src/core/CPP/kernels/*.cpp", - "src/core/helpers/*.cpp", - "src/core/utils/**/*.cpp", - "src/runtime/**/*.cpp", - "src/c/*.cpp", - "src/core/NEON/kernels/*.cpp", - "src/core/NEON/kernels/convolution/**/*.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_*/*.cpp", - "src/core/NEON/kernels/arm_conv/pooling/*.cpp", - "src/core/NEON/kernels/arm_conv/**/kernels/a64_*/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/a64_*.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic*.cpp", - "src/core/NEON/kernels/batchnormalization/impl/NEON/*.cpp", - "src/cpu/*.cpp", - "src/cpu/kernels/*.cpp", - "src/cpu/kernels/fuse_batch_normalization/**/*.cpp", - "src/cpu/kernels/*/generic/*.cpp", - "src/cpu/operators/**/*.cpp", - "src/cpu/utils/*.cpp", - "src/cpu/kernels/internal/*.cpp", - "src/cpu/kernels/**/neon/*.cpp", - "src/cpu/kernels/**/nchw/*.cpp", - "src/core/NEON/kernels/arm_gemm/*.cpp", - "**/*.h", - "**/*.hpp", - "**/*.inl", - ], - exclude = [ - "src/core/utils/logging/**", - "src/core/TracePoint.cpp", - "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", - "src/core/NEON/kernels/arm_gemm/transform-sve.cpp", - "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", - "src/runtime/CL/**", - "src/gpu/**", - ], - ) + [ - "src/c/operators/AclActivation.cpp", - "src/core/CPP/CPPTypes.cpp", - "src/core/NEON/kernels/arm_conv/addressing.cpp", - "src/core/NEON/kernels/arm_conv/depthwise/interleaves/8b_mla.cpp", - "src/core/NEON/kernels/arm_conv/pooling/kernels/cpp_nhwc_1x1_stride_any_depthfirst/generic.cpp", - ], - hdrs = glob([ - "src/core/NEON/kernels/**/*.h", - "src/core/NEON/kernels/**/*.hpp", - "arm_compute/runtime/**/*.h", - "arm_compute/runtime/*.h", - "arm_compute/core/**/*.h", - "**/*.inl", - ]) + [ - "arm_compute_version.embed", - ], - copts = [ - "-march=armv8-a", - "-fopenmp", - ], - defines = _COMPUTE_LIBRARY_DEFINES, - includes = [ - "arm_compute/runtime", - "src/core/NEON/kernels/assembly", - "src/core/NEON/kernels/convolution/common", - "src/core/NEON/kernels/convolution/winograd", - "src/core/cpu/kernels/assembly", - "src/cpu/kernels/assembly", - ], - linkopts = ["-fopenmp"], - visibility = ["//visibility:public"], - deps = [ - "arm_compute_sve", - "arm_compute_sve2", - "include", - ], -) - config_setting( - name = "build_with_acl", - define_values = { - "build_with_acl": "true", - }, - visibility = ["//visibility:public"], + name = "build_with_acl", + define_values = { + "build_with_acl": "true", + }, + visibility = ["//visibility:public"], ) bzl_library( diff --git a/third_party/compute_library/acl_fixed_format_kernels_striding.patch b/third_party/compute_library/acl_fixed_format_kernels_striding.patch deleted file mode 100644 index 8e501a1d6d9c79..00000000000000 --- 
a/third_party/compute_library/acl_fixed_format_kernels_striding.patch +++ /dev/null @@ -1,70 +0,0 @@ - ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -index 77da83070..985f96761 100644 ---- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp -@@ -495,48 +495,6 @@ void Fallback::run(ITensorPack &tensors) - { - ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); - multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); -- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); -- if(is_fixed_format(wf)) -- { -- // The 4D tensor of dimension O'HWI' created for the -- // OHWIoi format is in reality seen -- // as a 2D tensor at arm_gemm level, where the rows are -- // O'/ and the columns are * -- // H * W * I'. -- ITensorInfo *tensor_info = b->info(); -- const DataLayout data_layout = tensor_info->data_layout(); -- const TensorShape tensor_shape = tensor_info->tensor_shape(); -- const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; -- const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; -- int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; -- const int interleave_by = arm_compute::interleave_by(wf); -- const int blocked_by = arm_compute::block_by(wf); -- // We need to find a new stride that is distance from the data for one -- // set of output channels to the next -- if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width) -- { -- // In this case dimensions that are packed are height, width and channel -- // so we need to stride it by interleave_by -- if(tensor_channels % blocked_by != 0) -- { -- // We need to pad -- tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by; -- } -- ldb = interleave_by * tensor_height * tensor_width * tensor_channels; -- } -- else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) -- { -- // In this case dimension that is packed is only height -- // so we need to stride only height by interleave_by -- ldb = interleave_by * tensor_height; -- } -- else -- { -- // If dimensions are not packed as above error is thrown -- // as at the moment other forms of packing are not supported -- ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel"); -- } -- } - in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); - } - diff --git 
a/third_party/compute_library/acl_openmp_fix.patch b/third_party/compute_library/acl_openmp_fix.patch deleted file mode 100644 index 512148c8eca114..00000000000000 --- a/third_party/compute_library/acl_openmp_fix.patch +++ /dev/null @@ -1,46 +0,0 @@ - ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. - SPDX-License-Identifier: Apache-2.0 - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ******************************************************************************* - -diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp -index aad24b4f0..78d1523af 100644 ---- a/src/runtime/OMP/OMPScheduler.cpp -+++ b/src/runtime/OMP/OMPScheduler.cpp -@@ -90,18 +116,21 @@ void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const Win - void OMPScheduler::run_workloads(std::vector &workloads) - { - const unsigned int amount_of_work = static_cast(workloads.size()); -- if(amount_of_work < 1 || _num_threads == 1) -+ const unsigned int num_threads_to_use = std::min(_num_threads, amount_of_work ); -+ -+ if(amount_of_work < 1 || num_threads_to_use == 1) - { - return; - } - - ThreadInfo info; - info.cpu_info = &cpu_info(); -- info.num_threads = _num_threads; -- #pragma omp parallel for firstprivate(info) num_threads(_num_threads) default(shared) proc_bind(close) schedule(static, 1) -+ info.num_threads = num_threads_to_use; -+ #pragma omp parallel for firstprivate(info) num_threads(num_threads_to_use) default(shared) proc_bind(close) schedule(static, 1) - for(unsigned int wid = 0; wid < amount_of_work; ++wid) - { - const int tid = omp_get_thread_num(); -+ - info.thread_id = tid; - workloads[wid](info); - } diff --git a/third_party/compute_library/compute_library.patch b/third_party/compute_library/compute_library.patch index 2b9619dd03503f..5a86e2f68f9df8 100644 --- a/third_party/compute_library/compute_library.patch +++ b/third_party/compute_library/compute_library.patch @@ -1,8 +1,43 @@ -diff --git a/arm_compute_version.embed b/arm_compute_version.embed -new file mode 100644 -index 000000000..c986ad52a ---- /dev/null -+++ b/arm_compute_version.embed -@@ -0,0 +1,1 @@ -+"arm_compute_version=v22.11 Build options: {} Git hash=b'1b3192e8a23513031163dc14d248f47671986121'" -\ No newline at end of file + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
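The OMPScheduler hunk in the acl_openmp_fix.patch removed above caps the OpenMP team size at the number of workloads so that no idle threads are spawned. A minimal standalone sketch of that idea, using plain unsigned ints rather than the scheduler's members (the function name is illustrative only, not taken from either patch):

#include <algorithm>

// Never request more OpenMP threads than there are independent workloads;
// surplus threads would only add fork/join overhead.
inline unsigned int capped_thread_count(
        unsigned int configured_threads, unsigned int amount_of_work) {
    return std::min(configured_threads, amount_of_work);
}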
+ ******************************************************************************* +diff --git a/BUILD.bazel b/BUILD.bazel +index f1766d958..0cb51f52d 100644 +--- a/BUILD.bazel ++++ b/BUILD.bazel +@@ -239,9 +239,11 @@ cc_library( + }), + visibility = ["//visibility:public"], + deps = [ +- "arm_compute", + "//:common_defines", + "//arm_compute:graph_headers", ++ "//include", ++ "//support", ++ "//utils", + ], + alwayslink = True, + ) +@@ -407,7 +409,8 @@ cc_library( + "//support", + "//utils", + "//:arm_compute_sve", +- "//:arm_compute_sve2" ++ "//:arm_compute_sve2", ++ "//:arm_compute_graph" + ], + alwayslink = True, + ) diff --git a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch index 95f0374ec4ddd3..950077665fb4b7 100644 --- a/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch +++ b/third_party/mkl_dnn/onednn_acl_depthwise_convolution.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. + Copyright 2023 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,87 +14,93 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* - diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index fc93d2aa9..6ebac0d17 100644 +index 6b57374643..85e45ace9d 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -54,10 +54,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -48,11 +48,14 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + if (!is_fwd) return status::unimplemented; + const int ndims = src_d.ndims(); - const bool is_1d = ndims == 3; - const bool is_3d = ndims == 5; -+ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 && wei_d.dims()[2] == 1; -+ - bool is_nspc; ++ const bool is_depthwise = wei_d.ndims() == 5 && wei_d.dims()[1] == 1 ++ && wei_d.dims()[2] == 1; - // Compute Library unsupported shape scenarios -- if (one_of(true, is_3d, is_1d, with_groups)) { -+ if (one_of(true, is_3d, is_1d, (with_groups && !is_depthwise))) { - return status::unimplemented; - } +- ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); ++ ACL_CHECK_SUPPORT( ++ ndims != 4 && !is_depthwise, " only supports 2 spatial dimensions"); -@@ -135,11 +137,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - is_nspc = utils::one_of(src_tag, nhwc); + const int with_groups = wei_d.ndims() == src_d.ndims() + 1; +- ACL_CHECK_SUPPORT(with_groups, " does not support groups"); ++ ACL_CHECK_SUPPORT(with_groups && !is_depthwise, " does not support groups"); - memory_desc_t want_wei_md = weights_md; -- auto wei_tag = is_nspc ? ohwi : oihw; -+ auto wei_tag = is_depthwise ? hwigo : (is_nspc ? 
ohwi : oihw); - CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); + ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 + || wei_d.data_type() != data_type::f32 +@@ -108,7 +111,8 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - // Compute Library does not support mismatching layouts -- if ((src_tag != wei_tag) || (src_tag != dst_tag)) -+ if (!is_depthwise && ((src_tag != wei_tag) || (src_tag != dst_tag))) - return status::unimplemented; + acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; - if (weights_md.format_kind == format_kind::any) { -@@ -187,6 +189,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - acl_wei_data_t, - acl_layout); +- if (wei_d.format_kind() != format_kind::any) return status::unimplemented; ++ if (wei_d.format_kind() != format_kind::any && !is_depthwise) ++ return status::unimplemented; -+ if(is_depthwise) { -+ // We need to set that values are not constant so that we -+ // we can update them in-place in ACL -+ acp.wei_info.set_are_values_constant(false); + auto src_tag = memory_desc_matches_one_of_tag( + src_md, format_tag::nhwc, format_tag::nchw); +@@ -138,8 +142,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + || src_tag != dst_tag) + return status::unimplemented; + +- // Set weights to initially be the same as src +- CHECK(memory_desc_init_by_tag(weights_md, src_tag)); ++ if (is_depthwise) { ++ CHECK(memory_desc_init_by_tag(weights_md, format_tag::hwigo)); ++ } else { ++ // Set weights to initially be the same as src ++ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); + } -+ - acp.dst_info = arm_compute::TensorInfo( - is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : - arm_compute::TensorShape(ow, oh, oc, mb), -@@ -212,6 +220,12 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - arm_compute::QuantizationInfo(1.0f / scales[0], 0)); - } + // Bias is just 1D, set to be the obvious format + if (acp.with_bias && bias_md.format_kind == format_kind::any) +@@ -166,6 +174,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + 1, + acl_data_type, + acl_layout); + if(is_depthwise) { ++ // We need to set that values are not constant so that we ++ // we can update them in-place in ACL ++ acp.wei_tensor_info.set_are_values_constant(false); ++ } + + acp.dst_tensor_info = arm_compute::TensorInfo( + is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : +@@ -185,6 +198,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + // Are we allowed to cast down to bf16 or not? 
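The depthwise check threaded through the hunks above keys off the 5D grouped weight tensor: with oneDNN's grouped 2D convolution weights laid out as {g, oc/g, ic/g, kh, kw}, a depthwise convolution is the case where both per-group channel counts are 1. A minimal standalone restatement of that predicate, with an illustrative name (not taken from the patch):

// True when grouped 2D conv weights describe a depthwise convolution,
// i.e. one output and one input channel per group.
static bool looks_like_depthwise(const memory_desc_wrapper &wei_d) {
    return wei_d.ndims() == 5        // {g, oc/g, ic/g, kh, kw}
            && wei_d.dims()[1] == 1  // one output channel per group
            && wei_d.dims()[2] == 1; // one input channel per group
}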
+ acp.fast_math + = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ if (is_depthwise) { + // There is no support for fixed format kernels for depthwise convolution + // in ACL so we are going to use weight format that we set up earlier + return status::success; + } -+ + + // WeightFormat::ANY tells ACL we can handle any format acp.weights_info = arm_compute::WeightsInfo( - false, - kw, -@@ -302,6 +316,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -252,6 +270,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr) { - acp.is_indirect = false; ++ if (weights_md.ndims != 4) return status::unimplemented; -+ if(weights_md.ndims != 4) { -+ return status::unimplemented; -+ } -+ // General Compute Library checks, memory tags are also set there CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); +@@ -277,6 +296,7 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { ++ if (weights_md.ndims != 4) return status::unimplemented; -@@ -330,7 +348,8 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - auto math_mode = get_fpmath_mode(); - // Indirect convolution results in slowdown for low thread count or 1x1 - // kernels, so fall back to GEMM-based convolution in these cases -- if (one_of(true, weights_md.dims[2] == 1, // kh -+ if (one_of(true, weights_md.ndims != 4, -+ weights_md.dims[2] == 1, // kh - weights_md.dims[3] == 1, // kw - (!math_mode && dnnl_get_max_threads() < 28))) { - return status::unimplemented; -@@ -355,6 +374,27 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + // Indirect is slower for small convolution kernels + if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) +@@ -314,6 +334,22 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, return status::success; } @@ -102,41 +108,26 @@ index fc93d2aa9..6ebac0d17 100644 + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { -+ acp.is_indirect = false; -+ // We need to make sure that number of dimensions for weights is either 5 or 3 -+ if(weights_md.ndims != 5) -+ return status::unimplemented; ++ if (weights_md.ndims != 5) return status::unimplemented; + + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + + ACL_CHECK_VALID(arm_compute::NEDepthwiseConvolutionLayer::validate( -+ &acp.src_info, -+ &acp.wei_info, -+ acp.with_bias ? &acp.bia_info : nullptr, -+ &acp.dst_info, -+ acp.padstride_info)); ++ &acp.src_tensor_info, &acp.wei_tensor_info, ++ acp.with_bias ? 
&acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info)); + + return status::success; +} + - status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, -@@ -364,7 +404,8 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm - // clang-format off -- if (one_of(true, src_md.dims[2] > 112, // ih -+ if (one_of(true, weights_md.ndims != 4, -+ src_md.dims[2] > 112, // ih - src_md.dims[3] > 112, // iw - src_md.dims[1] < 64, // ic - dst_md.dims[1] < 64, // oc + } // namespace acl_convolution_utils + + } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 44dc8eecb..7eae5cbb1 100644 +index e3d40a5e75..1ded5826c4 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -67,6 +67,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -66,6 +66,11 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr); @@ -145,37 +136,17 @@ index 44dc8eecb..7eae5cbb1 100644 + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + - status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, -diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp -index 4142dbc7e..1800aaf58 100644 ---- a/src/cpu/cpu_convolution_list.cpp -+++ b/src/cpu/cpu_convolution_list.cpp -@@ -65,6 +65,7 @@ using namespace dnnl::impl::cpu::x64; - #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL - #include "cpu/aarch64/acl_gemm_convolution.hpp" - #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" -+#include "cpu/aarch64/acl_depthwise_convolution.hpp" - #include "cpu/aarch64/acl_winograd_convolution.hpp" - #endif - using namespace dnnl::impl::cpu::aarch64; -@@ -104,6 +105,7 @@ const std::map> &impl_list_map() - CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) - CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) - CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) -+ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) - CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) - CPU_INSTANCE(gemm_convolution_fwd_t) + } // namespace acl_convolution_utils + + template _lock {this->mtx}; -+ -+ auto *acl_resource -+ = ctx.get_resource_mapper()->get( -+ this); -+ acl_obj_t &acl_depthwise_obj -+ = acl_resource->get_acl_obj(); -+ -+ return execute_forward_conv_acl, pd_t, -+ data_t>(ctx, acl_depthwise_obj, pd()); -+ } -+ -+} -+} -+} ++ const exec_ctx_t &ctx) const { ++ std::lock_guard _lock {this->mtx}; ++ ++ auto *acl_resource ++ = ctx.get_resource_mapper() ++ ->get(this); ++ acl_obj_t &acl_depthwise_obj ++ = acl_resource->get_acl_obj(); ++ ++ return execute_forward_conv_acl< ++ acl_obj_t, pd_t, data_t>( ++ ctx, acl_depthwise_obj, pd()); +} ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl diff --git a/src/cpu/aarch64/acl_depthwise_convolution.hpp b/src/cpu/aarch64/acl_depthwise_convolution.hpp new file mode 100644 
-index 000000000..d84fc4fb5 +index 0000000000..3e3d02cf41 --- /dev/null +++ b/src/cpu/aarch64/acl_depthwise_convolution.hpp -@@ -0,0 +1,139 @@ +@@ -0,0 +1,141 @@ +/******************************************************************************* -+* Copyright 2022 Arm Ltd. and affiliates ++* Copyright 2023 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. @@ -240,8 +212,8 @@ index 000000000..d84fc4fb5 +#ifndef CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +#define CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP + -+#include "cpu/cpu_convolution_pd.hpp" +#include "cpu/aarch64/acl_convolution_utils.hpp" ++#include "cpu/cpu_convolution_pd.hpp" + +namespace dnnl { +namespace impl { @@ -250,15 +222,16 @@ index 000000000..d84fc4fb5 + +struct acl_depthwise_convolution_resource_t : public resource_t { + acl_depthwise_convolution_resource_t() -+ : acl_obj_(utils::make_unique>()) {} ++ : acl_obj_(utils::make_unique< ++ acl_obj_t>()) {} + + status_t configure(const acl_conv_conf_t &acp) { -+ if(!acl_obj_) return status::out_of_memory; ++ if (!acl_obj_) return status::out_of_memory; + -+ acl_obj_->src_tensor.allocator()->init(acp.src_info); -+ acl_obj_->wei_tensor.allocator()->init(acp.wei_info); -+ acl_obj_->dst_tensor.allocator()->init(acp.dst_info); -+ acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); + + // clang-format off + acl_obj_->conv.configure( @@ -281,14 +254,14 @@ index 000000000..d84fc4fb5 + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_depthwise_convolution_resource_t); + +private: -+ std::unique_ptr> acl_obj_; -+ ++ std::unique_ptr> ++ acl_obj_; +}; + +struct acl_depthwise_convolution_fwd_t : public primitive_t { + + struct pd_t : public cpu_convolution_fwd_pd_t { -+ pd_t(const convolution_desc_t* adesc, const primitive_attr_t *attr, ++ pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), acp_() {} + @@ -297,16 +270,18 @@ index 000000000..d84fc4fb5 + + status_t init(engine_t *engine) { + using namespace data_type; -+ using smask_t = primitive_attr_t::skip_mask_t; + ++ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f16); ++ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f32); + bool ok = is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) -+ && expect_data_types(data_type::f32, data_type::f32, -+ data_type::f32, data_type::f32, undef) -+ && !has_zero_dim_memory() -+ && attr()->has_default_values( -+ smask_t::post_ops, data_type::f32); -+ if(!ok) return status::unimplemented; ++ && utils::one_of(true, is_fp16_ok, is_fp32_ok) ++ && !has_zero_dim_memory(); ++ if (!ok) return status::unimplemented; + + CHECK(acl_convolution_utils::init_conf_depthwise(acp_, src_md_, + weights_md_, dst_md_, bias_md_, *desc(), *attr())); @@ -326,32 +301,31 @@ index 000000000..d84fc4fb5 + acl_depthwise_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + status_t create_resource( -+ engine_t *engine, resource_mapper_t &mapper) const 
override { -+ if(mapper.has_resource(this)) return status::success; ++ engine_t *engine, resource_mapper_t &mapper) const override { ++ if (mapper.has_resource(this)) return status::success; + -+ auto r = utils::make_unique(); -+ if(!r) return status::out_of_memory; ++ auto r = utils::make_unique(); ++ if (!r) return status::out_of_memory; + -+ CHECK(r->configure(pd()->acp_)); -+ mapper.add(this, std::move(r)); ++ CHECK(r->configure(pd()->acp_)); ++ mapper.add(this, std::move(r)); + -+ CHECK(pd()->post_ops.create_resource(engine, mapper)); ++ CHECK(pd()->post_ops.create_resource(engine, mapper)); + -+ return status::success; -+ } ++ return status::success; ++ } + -+ typedef typename prec_traits::type data_t; ++ typedef typename prec_traits::type data_t; + -+ status_t execute(const exec_ctx_t &ctx) const override { -+ return execute_forward(ctx); -+ } ++ status_t execute(const exec_ctx_t &ctx) const override { ++ return execute_forward(ctx); ++ } + +private: + mutable std::mutex mtx; + status_t execute_forward(const exec_ctx_t &ctx) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } -+ +}; + +} // namespace aarch64 @@ -360,3 +334,23 @@ index 000000000..d84fc4fb5 +} // namespace dnnl + +#endif // CPU_AARCH64_ACL_DEPTHWISE_CONVOLUTION_HPP +diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp +index 094c73aa36..80385432d8 100644 +--- a/src/cpu/cpu_convolution_list.cpp ++++ b/src/cpu/cpu_convolution_list.cpp +@@ -63,6 +63,7 @@ using namespace dnnl::impl::cpu::x64; + #include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp" + #include "cpu/aarch64/jit_uni_dw_convolution.hpp" + #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL ++#include "cpu/aarch64/acl_depthwise_convolution.hpp" + #include "cpu/aarch64/acl_gemm_convolution.hpp" + #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" + #endif +@@ -102,6 +103,7 @@ const std::map> &impl_list_map() + CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) + CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) ++ CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t) + CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t) + CPU_INSTANCE(gemm_convolution_fwd_t) diff --git a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch index 2c8af08ab8a4ff..5d918564fb1515 100644 --- a/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch +++ b/third_party/mkl_dnn/onednn_acl_fixed_format_kernels.patch @@ -1,5 +1,5 @@ ******************************************************************************* - Copyright 2022 Arm Limited and affiliates. + Copyright 2023 Arm Limited and affiliates. SPDX-License-Identifier: Apache-2.0 Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,178 +14,479 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************* - +diff --git a/src/common/matmul_pd.hpp b/src/common/matmul_pd.hpp +index 4330ad938b..df16c5fcca 100644 +--- a/src/common/matmul_pd.hpp ++++ b/src/common/matmul_pd.hpp +@@ -159,6 +159,19 @@ protected: + + return true; + } ++ ++ // All implementations that do not support sparse inputs/outputs should ++ // call this function. 
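++ // As a usage sketch only (the surrounding checks are hypothetical and not
++ // part of this patch), a dense-only matmul implementation would gate its
++ // pd init on the helper added below, along the lines of:
++ //
++ //     const bool ok = is_dense_data() // reject any sparse src/wei/bias/dst
++ //             && /* the implementation's other dispatch checks */ true;
++ //     if (!ok) return status::unimplemented;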
++ bool is_dense_data() { ++#ifdef DNNL_EXPERIMENTAL_SPARSE ++ for (auto md : {&src_md_, &weights_md_, &bias_md_, &dst_md_}) { ++ if (memory_desc_wrapper(md).format_kind() == format_kind::sparse) ++ return false; ++ } ++#endif ++ return true; ++ } ++ + }; + + } // namespace impl diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp -index c46d69757..fc93d2aa9 100644 +index 37f8ecbc06..6b57374643 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp -@@ -212,6 +212,87 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, - arm_compute::QuantizationInfo(1.0f / scales[0], 0)); - } +@@ -41,25 +41,23 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + const memory_desc_wrapper dst_d(&dst_md); + const memory_desc_wrapper bia_d(&bias_md); + +- auto math_mode = get_fpmath_mode(); +- acp.fast_math = one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); +- + // Compute Library currently supports forward propagation only + const prop_kind_t prop_kind = cd.prop_kind; + const bool is_fwd = (prop_kind == dnnl_forward_training) + || (prop_kind == dnnl_forward_inference); + if (!is_fwd) return status::unimplemented; + +- const int with_groups = wei_d.ndims() == src_d.ndims() + 1; + const int ndims = src_d.ndims(); +- const bool is_1d = ndims == 3; +- const bool is_3d = ndims == 5; +- bool is_nspc; + +- // Compute Library unsupported shape scenarios +- if (one_of(true, is_3d, is_1d, with_groups)) { +- return status::unimplemented; +- } ++ ACL_CHECK_SUPPORT(ndims != 4, " only supports 2 spatial dimensions"); ++ ++ const int with_groups = wei_d.ndims() == src_d.ndims() + 1; ++ ACL_CHECK_SUPPORT(with_groups, " does not support groups"); ++ ++ ACL_CHECK_SUPPORT(src_d.data_type() != data_type::f32 ++ || wei_d.data_type() != data_type::f32 ++ || dst_d.data_type() != data_type::f32, ++ " src, dst and wei must be fp32"); + + // batch size + const int mb = src_d.dims()[0]; +@@ -110,108 +108,143 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, + + acp.with_bias = cd.bias_desc.format_kind != format_kind::undef; + +- auto set_or_check_tags = [&](format_tag_t desired_src_tag, +- format_tag_t desired_dst_tag) -> status_t { +- using namespace format_tag; +- auto src_tag = any, dst_tag = any; +- +- if (src_d.format_kind() == format_kind::any) { +- CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); +- src_tag = desired_src_tag; +- } else { +- src_tag = memory_desc_matches_one_of_tag(src_md, nhwc, nchw); +- } +- +- if (dst_d.format_kind() == format_kind::any) { +- CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); +- dst_tag = desired_dst_tag; +- } else { +- dst_tag = memory_desc_matches_one_of_tag(dst_md, nhwc, nchw); +- } +- +- if (acp.with_bias && bias_md.format_kind == format_kind::any) +- CHECK(memory_desc_init_by_tag(bias_md, x)); +- +- is_nspc = utils::one_of(src_tag, nhwc); +- +- memory_desc_t want_wei_md = weights_md; +- auto wei_tag = is_nspc ? 
ohwi : oihw; +- CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); +- +- // Compute Library does not support mismatching layouts +- if ((src_tag != wei_tag) || (src_tag != dst_tag)) +- return status::unimplemented; ++ if (wei_d.format_kind() != format_kind::any) return status::unimplemented; ++ ++ auto src_tag = memory_desc_matches_one_of_tag( ++ src_md, format_tag::nhwc, format_tag::nchw); ++ auto dst_tag = memory_desc_matches_one_of_tag( ++ dst_md, format_tag::nhwc, format_tag::nchw); ++ ++ // We want src and dst to match, preferrably both to be NHWC ++ if (src_d.format_kind() == format_kind::any ++ && dst_d.format_kind() == format_kind::any) { ++ CHECK(memory_desc_init_by_tag(src_md, format_tag::nhwc)); ++ CHECK(memory_desc_init_by_tag(dst_md, format_tag::nhwc)); ++ } else if (src_d.format_kind() == format_kind::any ++ && dst_tag != format_tag::undef) { ++ CHECK(memory_desc_init_by_tag(src_md, dst_tag)); ++ } else if (dst_d.format_kind() == format_kind::any ++ && src_tag != format_tag::undef) { ++ CHECK(memory_desc_init_by_tag(dst_md, src_tag)); ++ } + +- if (weights_md.format_kind == format_kind::any) { +- weights_md = want_wei_md; +- } +- return (want_wei_md == weights_md) ? status::success +- : status::unimplemented; +- }; ++ // Recompute tags after potentially running memory desc init ++ src_tag = memory_desc_matches_one_of_tag( ++ src_md, format_tag::nhwc, format_tag::nchw); ++ dst_tag = memory_desc_matches_one_of_tag( ++ dst_md, format_tag::nhwc, format_tag::nchw); + +- auto default_dat_tag = format_tag::nhwc; +- if (set_or_check_tags(default_dat_tag, default_dat_tag) != status::success) ++ if (src_tag == format_tag::undef || dst_tag == format_tag::undef ++ || src_tag != dst_tag) + return status::unimplemented; + +- const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ // Set weights to initially be the same as src ++ CHECK(memory_desc_init_by_tag(weights_md, src_tag)); + +- // For convolutions, int8 datatypes imply quantized types in ACL +- acp.is_int8 = utils::one_of(src_d.data_type(), s8, u8) +- && wei_d.data_type() == s8; ++ // Bias is just 1D, set to be the obvious format ++ if (acp.with_bias && bias_md.format_kind == format_kind::any) ++ CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); + +- auto acl_src_data_t +- = acl_utils::get_acl_data_t(src_d.data_type(), acp.is_int8); +- auto acl_wei_data_t +- = acl_utils::get_acl_data_t(wei_d.data_type(), acp.is_int8); +- auto acl_dst_data_t +- = acl_utils::get_acl_data_t(dst_d.data_type(), acp.is_int8); +- auto acl_bia_data_t +- = acl_utils::get_acl_data_t(bia_d.data_type(), acp.is_int8); ++ bool is_nhwc = src_tag == format_tag::nhwc; ++ // The layouts have to match (although we may later modify the weights) ++ const auto acl_layout = is_nhwc ? arm_compute::DataLayout::NHWC ++ : arm_compute::DataLayout::NCHW; + +- if (acl_bia_data_t == arm_compute::DataType::UNKNOWN) +- acl_bia_data_t = arm_compute::DataType::F32; ++ auto acl_data_type = arm_compute::DataType::F32; + // clang-format off +- acp.src_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(ic, iw, ih, mb) : ++ acp.src_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? arm_compute::TensorShape(ic, iw, ih, mb) : + arm_compute::TensorShape(iw, ih, ic, mb), + 1, +- acl_src_data_t, ++ acl_data_type, + acl_layout); + +- acp.wei_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(ic, kw, kh, oc) : ++ acp.wei_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? 
arm_compute::TensorShape(ic, kw, kh, oc) : + arm_compute::TensorShape(kw, kh, ic, oc), + 1, +- acl_wei_data_t, ++ acl_data_type, + acl_layout); + +- acp.dst_info = arm_compute::TensorInfo( +- is_nspc ? arm_compute::TensorShape(oc, ow, oh, mb) : ++ acp.dst_tensor_info = arm_compute::TensorInfo( ++ is_nhwc ? arm_compute::TensorShape(oc, ow, oh, mb) : + arm_compute::TensorShape(ow, oh, oc, mb), + 1, +- acl_dst_data_t, ++ acl_data_type, + acl_layout); + +- acp.bia_info = arm_compute::TensorInfo( ++ acp.bia_tensor_info = arm_compute::TensorInfo( + acp.with_bias ? arm_compute::TensorShape(oc) + : arm_compute::TensorShape(), + 1, +- acl_bia_data_t, ++ acl_data_type, + acl_layout); + // clang-format on + +- // Add quantization info to tensors +- if (acp.is_int8) { +- const float *scales = attr.output_scales_.scales_; +- acp.src_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.bia_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.wei_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0)); +- acp.dst_info.set_quantization_info( +- arm_compute::QuantizationInfo(1.0f / scales[0], 0)); ++ // Are we allowed to cast down to bf16 or not? ++ acp.fast_math ++ = one_of(attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ ++ // WeightFormat::ANY tells ACL we can handle any format + acp.weights_info = arm_compute::WeightsInfo( -+ false, -+ kw, -+ kh, -+ oc, -+ false, -+ arm_compute::WeightFormat::ANY); ++ false, kw, kh, oc, false, arm_compute::WeightFormat::ANY); ++ ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists). Note that these are referred to as fixed format ++ // kernels, because they require one specific weights format + arm_compute::WeightFormat expected_weight_format; -+ auto acl_st = arm_compute::NEGEMMConvolutionLayer::has_opt_impl( -+ expected_weight_format, -+ &acp.src_info, -+ &acp.wei_info, -+ acp.with_bias ? &acp.bia_info : nullptr, -+ &acp.dst_info, -+ acp.padstride_info, -+ acp.weights_info, -+ acp.dilation_info, -+ acp.act_info, -+ acp.fast_math); -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( ++ expected_weight_format, &acp.src_tensor_info, &acp.wei_tensor_info, ++ acp.with_bias ? 
&acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, ++ acp.dilation_info, acp.act_info, acp.fast_math)); ++ ++ // Set weights info to the one returned by has_opt_impl + acp.weights_info.set_weight_format(expected_weight_format); + -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); ++ // has_opt_impl may return a non fast math kernel, even if we requested one ++ acp.fast_math ++ = arm_compute::is_fixed_format_fast_math(expected_weight_format); + -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); -+ if(!is_fast_math_kernel) { -+ // FP32 kernel is faster then BF16 -+ acp.fast_math = false; -+ } ++ // Map OIHW used in ACL WeightFormat to the logical dimensions of the memory descriptor ++ dim_t O_dim = 0; ++ dim_t I_dim = 1; ++ dim_t H_dim = 2; ++ dim_t W_dim = 3; + -+ memory_desc_t want_wei_md = weights_md; -+ -+ int ic_multiply = ic; -+ if(ic % block_by != 0) { -+ ic_multiply = utils::div_up(ic, block_by) * block_by; -+ // Also we need to set padded dimensions as well -+ want_wei_md.padded_dims[1] = ic_multiply; -+ } else { -+ // If we do not need to pad input channels for fast math mode -+ // then it would be faster to run convolution with im2row -+ // instead of using indirect buffer -+ if(acp.fast_math && acp.is_indirect) { ++ if (!is_nhwc) { ++ // We can try to support NCHW by swapping IHW around, note that this ++ // requires weights_md.dims[I_dim] % block_by != 0 (see next block) ++ O_dim = 0; ++ I_dim = 3; ++ H_dim = 1; ++ W_dim = 2; + } + ++ // We can't currently support nchw and block_by != 1. If this is the case, ++ // try a non fast math kernel, which currently have no blocking ++ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); ++ if (!is_nhwc && weights_md.dims[I_dim] % block_by != 0 && acp.fast_math) { ++ acp.fast_math = false; ++ acp.weights_info.set_weight_format(arm_compute::WeightFormat::ANY); ++ ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::has_opt_impl( ++ expected_weight_format, &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? 
&acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, acp.padstride_info, acp.weights_info, ++ acp.dilation_info, acp.act_info, acp.fast_math)); ++ acp.weights_info.set_weight_format(expected_weight_format); ++ block_by = arm_compute::block_by(expected_weight_format); ++ // This shouldn't happen, because non-fastmath have no blocking, but ++ // guard against it because it would silently return incorrect results ++ if (weights_md.dims[I_dim] % block_by != 0) + return status::unimplemented; -+ } -+ } -+ if(oc % interleaved_by != 0) { -+ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[0] = padded_dim; -+ } -+ -+ // Set strides based on blocking information -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by*ic_multiply*kw*kh; -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*ic_multiply*kw; -+ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*ic_multiply; -+ -+ acl_utils::update_strides_y_and_z( -+ acp.wei_info, -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size(), -+ acp.wei_info.strides_in_bytes().z()); -+ -+ // Set blocking -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; // second to last dimension in abcd format -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; // second to last dimension in abcd format -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ // If it is fast math mode we need weights in BFloat16 -+ want_wei_md.data_type = dnnl_bf16; + } + -+ weights_md = want_wei_md; ++ acl_utils::reorder_to_weight_format(acp.wei_tensor_info, weights_md, ++ expected_weight_format, I_dim, O_dim, {W_dim, H_dim}, {}); + return status::success; } -@@ -219,6 +300,7 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, +@@ -226,10 +259,10 @@ status_t init_conf_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + // clang-format off + // Validate convolution manually to check for return status + ACL_CHECK_VALID(arm_compute::NEGEMMConvolutionLayer::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? &acp.bia_info : nullptr, +- &acp.dst_info, ++ &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? 
&acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, + acp.padstride_info, + acp.weights_info, + acp.dilation_info, +@@ -244,28 +277,38 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr) { -+ acp.is_indirect = false; - - // General Compute Library checks, memory tags are also set there - CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); -@@ -244,11 +326,13 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr) { -+ acp.is_indirect = true; -+ auto math_mode = get_fpmath_mode(); - // Indirect convolution results in slowdown for low thread count or 1x1 - // kernels, so fall back to GEMM-based convolution in these cases - if (one_of(true, weights_md.dims[2] == 1, // kh - weights_md.dims[3] == 1, // kw +- // Indirect convolution results in slowdown for low thread count or 1x1 +- // kernels, so fall back to GEMM-based convolution in these cases +- if (one_of(true, weights_md.dims[2] == 1, // kh +- weights_md.dims[3] == 1, // kw - dnnl_get_max_threads() < 28)) { -+ (!math_mode && dnnl_get_max_threads() < 28))) { ++ ++ // Indirect is slower for small convolution kernels ++ if (weights_md.dims[2] == 1 && weights_md.dims[3] == 1) return status::unimplemented; - } +- } -@@ -275,6 +359,7 @@ status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, - memory_desc_t &weights_md, memory_desc_t &dst_md, - memory_desc_t &bias_md, const convolution_desc_t &cd, - const primitive_attr_t &attr) { -+ acp.is_indirect = false; + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + ++ // Indirect is slower than gemm for low thread counts, except for fast math ++ if (dnnl_get_max_threads() < 28 && !acp.fast_math) ++ return status::unimplemented; ++ ++ // If we do not need to pad input channels for fast math mode then it would ++ // be faster to run convolution with im2row instead of using indirect kernel ++ int block_by = arm_compute::block_by(acp.weights_info.weight_format()); ++ int ic = src_md.dims[1]; ++ if (acp.fast_math && ic % block_by == 0) return status::unimplemented; ++ ++ // TODO: remove this once NEGEMMConv2d::validate allows src and weights to mismatch ++ acp.wei_tensor_info.set_data_layout(arm_compute::DataLayout::NHWC); ++ + // clang-format off + // NOTE: indirect convolution method supports only nhwc layout. + ACL_CHECK_VALID(arm_compute::NEGEMMConv2d::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? &acp.bia_info : nullptr, +- &acp.dst_info, ++ &acp.src_tensor_info, ++ &acp.wei_tensor_info, ++ acp.with_bias ? 
&acp.bia_tensor_info : nullptr, ++ &acp.dst_tensor_info, + arm_compute::Conv2dInfo(acp.padstride_info, + acp.dilation_info, + acp.act_info, + acp.fast_math, +- 1))); ++ 1, {}, acp.weights_info))); + // clang-format on - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm + return status::success; diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp -index 3e56245fa..44dc8eecb 100644 +index 0398ab06b9..e3d40a5e75 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp -@@ -43,6 +43,7 @@ struct acl_conv_conf_t { +@@ -38,17 +38,17 @@ struct acl_obj_t { + + struct acl_conv_conf_t { + bool with_bias; +- bool is_int8; + bool fast_math; // If this is true, the result of the convolution goes into a temporarily // allocated ACL tensor to be accumulated into the oneDNN dst during postops bool use_dst_acc; -+ bool is_indirect; - arm_compute::TensorInfo src_info; - arm_compute::TensorInfo wei_info; - arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo src_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo src_tensor_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo bia_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; + arm_compute::PadStrideInfo padstride_info; + arm_compute::Size2D dilation_info; ++ // Additional information about the weights not included in wei_tensor_info + arm_compute::WeightsInfo weights_info; + // Note: this will default to not enabled, and will do nothing + arm_compute::ActivationLayerInfo act_info; +diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/aarch64/acl_gemm_convolution.hpp +index 485db954ea..da58e4f610 100644 +--- a/src/cpu/aarch64/acl_gemm_convolution.hpp ++++ b/src/cpu/aarch64/acl_gemm_convolution.hpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2020-2022 Arm Ltd. and affiliates ++* Copyright 2020-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -36,10 +36,10 @@ struct acl_resource_t : public resource_t { + if (!acl_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); + + acl_obj_->conv.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + acp.with_bias ? 
&acl_obj_->bia_tensor : nullptr, diff --git a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -index bcf031a77..4ddc8cf91 100644 +index bcf031a771..b7c8dce894 100644 --- a/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_indirect_gemm_convolution.hpp -@@ -41,6 +41,7 @@ struct acl_indirect_gemm_resource_t : public resource_t { - acl_obj_->bia_tensor.allocator()->init(acp.bia_info); +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -35,10 +35,10 @@ struct acl_indirect_gemm_resource_t : public resource_t { + if (!acl_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_obj_->bia_tensor.allocator()->init(acp.bia_info); ++ acl_obj_->src_tensor.allocator()->init(acp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); ++ acl_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); // clang-format off -+ arm_compute::experimental::PostOpList empty_post_ops = arm_compute::experimental::PostOpList {}; acl_obj_->conv.configure( - &acl_obj_->src_tensor, - &acl_obj_->wei_tensor, -@@ -50,7 +51,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { +@@ -50,7 +50,9 @@ struct acl_indirect_gemm_resource_t : public resource_t { acp.dilation_info, acp.act_info, acp.fast_math, - 1)); + 1, -+ empty_post_ops, ++ {}, + acp.weights_info)); // clang-format on return status::success; diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp -index c5e507085..163ff066e 100644 +index c5e507085f..a27df640fb 100644 --- a/src/cpu/aarch64/acl_inner_product.hpp +++ b/src/cpu/aarch64/acl_inner_product.hpp -@@ -45,6 +45,7 @@ struct acl_ip_conf_t { - arm_compute::TensorInfo bia_info; - arm_compute::TensorInfo dst_info; +@@ -40,11 +40,13 @@ struct acl_ip_conf_t { + // If this is true, the result of the inner product goes into a temporarily + // allocated ACL tensor to be accumulated into the oneDNN dst during postops + bool use_dst_acc; +- arm_compute::TensorInfo src_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo bia_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo src_tensor_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo bia_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; arm_compute::FullyConnectedLayerInfo fc_info; ++ // Additional information about the weights not included in wei_tensor_info + arm_compute::WeightsInfo weights_info; }; struct acl_ip_resource_t : public resource_t { acl_ip_resource_t() : acl_ip_obj_(utils::make_unique()) {} -@@ -64,7 +65,8 @@ struct acl_ip_resource_t : public resource_t { +@@ -53,10 +55,10 @@ struct acl_ip_resource_t : public resource_t { + if (!acl_ip_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor +- acl_ip_obj_->src_tensor.allocator()->init(aip.src_info); +- acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_info); +- 
acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_info); +- acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_info); ++ acl_ip_obj_->src_tensor.allocator()->init(aip.src_tensor_info); ++ acl_ip_obj_->wei_tensor.allocator()->init(aip.wei_tensor_info); ++ acl_ip_obj_->dst_tensor.allocator()->init(aip.dst_tensor_info); ++ acl_ip_obj_->bia_tensor.allocator()->init(aip.bia_tensor_info); + + // clang-format off + acl_ip_obj_->fc.configure( +@@ -64,7 +66,8 @@ struct acl_ip_resource_t : public resource_t { &acl_ip_obj_->wei_tensor, aip.with_bias ? &acl_ip_obj_->bia_tensor : nullptr, &acl_ip_obj_->dst_tensor, @@ -195,41 +496,127 @@ index c5e507085..163ff066e 100644 // clang-format on return status::success; -@@ -156,8 +158,8 @@ struct acl_inner_product_fwd_t : public primitive_t { - src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) - : arm_compute::TensorShape(n, ic); +@@ -89,12 +92,16 @@ struct acl_inner_product_fwd_t : public primitive_t { + DECLARE_COMMON_PD_T("acl", acl_inner_product_fwd_t); + status_t init(engine_t *engine) { +- const bool ok = is_fwd() && !has_zero_dim_memory() +- && expect_data_types(data_type::f32, data_type::f32, +- data_type::f32, data_type::f32, data_type::f32) ++ using namespace data_type; ++ const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f16); ++ const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) + && attr()->has_default_values( +- primitive_attr_t::skip_mask_t::post_ops, +- data_type::f32) ++ primitive_attr_t::skip_mask_t::post_ops, f32); ++ const bool ok = is_fwd() && !has_zero_dim_memory() ++ && utils::one_of(true, is_fp16_ok, is_fp32_ok) ++ && weights_md_.format_kind == format_kind::any + && set_default_params() == status::success; + + if (!ok) return status::unimplemented; +@@ -121,88 +128,46 @@ struct acl_inner_product_fwd_t : public primitive_t { + ACL_CHECK_SUPPORT( + !(is_2d || is_4d), "ACL supports only 2d or 4d cases"); + +- // batch size +- const int n = src_md()->dims[0]; +- +- // input and output channels +- const int ic = src_md()->dims[1]; +- const int oc = dst_md()->dims[1]; +- +- // source spatial dimensions +- const int ih = is_4d ? src_md()->dims[ndims - 2] : 0; +- const int iw = is_4d ? src_md()->dims[ndims - 1] : 0; +- +- // weights spatial dimensions +- const int kh = is_4d ? weights_md()->dims[ndims - 2] : 0; +- const int kw = is_4d ? weights_md()->dims[ndims - 1] : 0; +- +- // Only NCHW or NHWC derivatives supported by ACL kernels + using namespace format_tag; +- auto src_tag = memory_desc_matches_one_of_tag( +- src_md_, nhwc, nchw, nc, cn); +- auto wei_tag = memory_desc_matches_one_of_tag( +- weights_md_, ohwi, oihw, oi, io); +- auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc, cn); ++ auto src_tag ++ = memory_desc_matches_one_of_tag(src_md_, nhwc, nchw, nc); ++ auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, nc); + + ACL_CHECK_SUPPORT( +- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), ++ utils::one_of(format_tag::undef, src_tag, dst_tag), + "unsupported memory layout"); + + ACL_CHECK_SUPPORT(is_2d && src_tag != dst_tag, + "for src and dst layouts must match"); + +- arm_compute::TensorShape src_shape, wei_shape; +- if (is_2d) { +- src_shape = (src_tag == nc) ? arm_compute::TensorShape(ic, n) +- : arm_compute::TensorShape(n, ic); +- - wei_shape = (wei_tag == io) ? 
arm_compute::TensorShape(oc, ic) - : arm_compute::TensorShape(ic, oc); -+ // For fixed format kernels weight shape is always io -+ wei_shape = arm_compute::TensorShape(oc, ic); - } - if (is_4d) { - src_shape = (src_tag == nhwc) -@@ -166,7 +168,8 @@ struct acl_inner_product_fwd_t : public primitive_t { - - // ACL requires the weights to be in 2D flattened shape - const int flattened_ic = is_4d ? ic * kh * kw : ic; +- } +- if (is_4d) { +- src_shape = (src_tag == nhwc) +- ? arm_compute::TensorShape(ic, iw, ih, n) +- : arm_compute::TensorShape(iw, ih, ic, n); +- +- // ACL requires the weights to be in 2D flattened shape +- const int flattened_ic = is_4d ? ic * kh * kw : ic; - wei_shape = arm_compute::TensorShape(flattened_ic, oc); -+ // For fixed format kernels weights shape is always io -+ wei_shape = arm_compute::TensorShape(oc, flattened_ic); - } - - arm_compute::DataLayout src_layout = (src_tag == nhwc) -@@ -183,6 +186,9 @@ struct acl_inner_product_fwd_t : public primitive_t { - aip.wei_info = arm_compute::TensorInfo( - wei_shape, 1, arm_compute::DataType::F32, wei_layout); - -+ aip.weights_info = arm_compute::WeightsInfo( -+ false, 1, 1, is_4d ? ic * kh *kw : ic, false, arm_compute::WeightFormat::ANY); -+ - aip.dst_info - = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), - 1, arm_compute::DataType::F32); -@@ -194,15 +200,7 @@ struct acl_inner_product_fwd_t : public primitive_t { +- } +- +- arm_compute::DataLayout src_layout = (src_tag == nhwc) +- ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ const dim_t ic_total = IC_total(); ++ const dim_t n = MB(); ++ const dim_t oc = OC(); + +- arm_compute::DataLayout wei_layout = (wei_tag == ohwi) +- ? arm_compute::DataLayout::NHWC +- : arm_compute::DataLayout::NCHW; ++ aip.src_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(ic_total, n), 1, ++ acl_utils::get_acl_data_t(src_md()->data_type)); + +- aip.src_info = arm_compute::TensorInfo( +- src_shape, 1, arm_compute::DataType::F32, src_layout); ++ // ACL requires the weights to be in 2D flattened shape ++ aip.wei_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(oc, ic_total), 1, ++ acl_utils::get_acl_data_t(weights_md(0)->data_type)); + +- aip.wei_info = arm_compute::TensorInfo( +- wei_shape, 1, arm_compute::DataType::F32, wei_layout); +- +- aip.dst_info +- = arm_compute::TensorInfo(arm_compute::TensorShape(oc, n), +- 1, arm_compute::DataType::F32); ++ auto acl_dst_data_t ++ = acl_utils::get_acl_data_t(dst_md()->data_type); ++ aip.dst_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(oc, n), 1, acl_dst_data_t); + + aip.with_bias = desc()->bias_desc.format_kind != format_kind::undef; +- aip.bia_info = arm_compute::TensorInfo(aip.with_bias ++ auto acl_bia_data_t = aip.with_bias ++ ? acl_utils::get_acl_data_t(weights_md(1)->data_type) ++ : acl_dst_data_t; ++ aip.bia_tensor_info = arm_compute::TensorInfo(aip.with_bias + ? 
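As a worked illustration of the padding the comment above relies on (plain ints stand in for the ACL/oneDNN types; the helper name is not from the patch): a fixed format kernel that interleaves the O dimension by interleave_by and blocks the I dimension by block_by needs both extents rounded up to those multiples, which is what utils::rnd_up does to the padded_dims of the memory descriptor.

// Round x up to the next multiple of `multiple` (multiple > 0).
static inline int round_up_to_multiple(int x, int multiple) {
    return ((x + multiple - 1) / multiple) * multiple;
}
// Example: an 8x4 kernel (interleave_by = 8, block_by = 4) applied to
// oc = 30 and a collapsed ic of 70 pads them to 32 and 72 respectively:
//   round_up_to_multiple(30, 8) == 32
//   round_up_to_multiple(70, 4) == 72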
arm_compute::TensorShape(oc) + : arm_compute::TensorShape(), 1, arm_compute::DataType::F32); - aip.fc_info.weights_trained_layout = wei_layout; +- aip.fc_info.weights_trained_layout = wei_layout; - if (is_2d && wei_tag != src_tag) { - // weights are already transposed - aip.fc_info.transpose_weights = false; @@ -243,294 +630,537 @@ index c5e507085..163ff066e 100644 // Fast math mode auto math_mode = get_fpmath_mode(); -@@ -214,6 +212,80 @@ struct acl_inner_product_fwd_t : public primitive_t { +@@ -214,15 +179,103 @@ struct acl_inner_product_fwd_t : public primitive_t { aip.fc_info.activation_info)); aip.use_dst_acc = post_ops.has_sum(); ++ // WeightFormat::ANY tells ACL we can handle any format ++ aip.weights_info = arm_compute::WeightsInfo(false, 1, 1, ic_total, ++ false, arm_compute::WeightFormat::ANY); ++ ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists) Note that these are referred to as fixed ++ // format kernels, because they require one specific weights format + arm_compute::WeightFormat expected_weight_format; -+ auto acl_st = arm_compute::NEFullyConnectedLayer::has_opt_impl( -+ expected_weight_format, -+ &aip.src_info, -+ &aip.wei_info, -+ aip.with_bias ? &aip.bia_info : nullptr, -+ &aip.dst_info, -+ aip.fc_info, -+ aip.weights_info); -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::has_opt_impl( ++ expected_weight_format, &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? &aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, aip.fc_info, aip.weights_info)); + ++ // Set weights info to the one returned by has_opt_impl + aip.weights_info.set_weight_format(expected_weight_format); + -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); ++ // has_opt_impl may return a non fast math kernel, even if requested ++ aip.fc_info.enable_fast_math ++ = arm_compute::is_fixed_format_fast_math( ++ expected_weight_format); + -+ if(!is_fast_math_kernel) { -+ // FP32 kernel might be faster for some cases then BF16 -+ aip.fc_info.enable_fast_math = false; -+ } -+ -+ memory_desc_t want_wei_md = weights_md_; ++ // Inner product is the same as the matmul n x (chw) * (ihw) x o ++ // (note that the src c and weights i both correspond to the input ++ // channel). ACL FullyConnectedLayer assumes the chw dimensions of ++ // src and ihw dimensions of weights are collapsed, so we need to ++ // make sure that they have the same layout. Given that weights are ++ // more often fixed, (so reorders can be hoisted) it makes sense to ++ // reorder the weights to fit the src. + -+ int ic_multiply = ic; -+ if(is_4d) { -+ ic_multiply = ic * kh * kw; ++ // For 4D tensors we need to: ++ // - reorder the ihw of the weights to match the src chw ++ // - collapse ihw ++ // - pad the collapsed ihw ++ // But there is not yet a way to express this collapse+pad as a ++ // reorder. So we try to reorder the weights to match the src, ++ // implicitly collapse ihw in our definition of the weights ++ // TensorInfo and hope that the inner_dim has zero padding ++ // (weights_md_.dims[inner_dim] % block_by == 0). If it does, we ++ // fall back to a kernel without blocking (currently this is ++ // equivalent to non-fastmath). 
+ -+ // Since we are flattening dimensions the memory descriptor -+ // should also be for 2D -+ want_wei_md.ndims = 2; ++ // 2D just works because we just pad the only dimension. + -+ want_wei_md.dims[1] = ic_multiply; -+ want_wei_md.padded_dims[1] = ic_multiply; -+ want_wei_md.format_desc.blocking.strides[1] = 1; ++ // o_dim is always the first logical dimension (oihw, ohwi, oi) ++ dim_t o_dim = 0; ++ dim_t inner_dim; ++ // Rest of logical dimensions in order of innermost to outermost ++ std::vector remaining_dims = {}; + -+ want_wei_md.dims[0] = oc; -+ want_wei_md.padded_dims[0] = want_wei_md.padded_dims[1]; -+ want_wei_md.padded_dims[0] = oc; ++ if (src_tag == nchw) { ++ inner_dim = 3; // w ++ remaining_dims = {2, 1}; // h, i ++ } else if (src_tag == nhwc) { ++ inner_dim = 1; // i ++ remaining_dims = {3, 2}; // w, h ++ } else { // Only remaining case is 2D (nc) ++ inner_dim = 1; // i ++ remaining_dims = {}; // No other dimensions for 2D + } + -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * block_by; -+ if(want_wei_md.dims[1] % block_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * want_wei_md.padded_dims[1]; -+ -+ if(oc % interleaved_by != 0) { -+ int padded_dim = utils::div_up(oc, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[0] = padded_dim; -+ } -+ -+ int data_type_size = memory_desc_wrapper(want_wei_md).data_type_size(); -+ acl_utils::update_strides_y_and_z( -+ aip.wei_info, -+ want_wei_md.format_desc.blocking.strides[0] * data_type_size, -+ want_wei_md.format_desc.blocking.strides[1] * data_type_size); -+ -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = 0; -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = 1; -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ want_wei_md.data_type = dnnl_bf16; ++ // Fallback ++ int block_by = arm_compute::block_by(expected_weight_format); ++ if (is_4d && weights_md_.dims[inner_dim] % block_by != 0 ++ && aip.fc_info.enable_fast_math) { ++ aip.fc_info.enable_fast_math = false; ++ aip.weights_info.set_weight_format( ++ arm_compute::WeightFormat::ANY); ++ ACL_CHECK_VALID( ++ arm_compute::NEFullyConnectedLayer::has_opt_impl( ++ expected_weight_format, &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? &aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, aip.fc_info, ++ aip.weights_info)); ++ aip.weights_info.set_weight_format(expected_weight_format); ++ block_by = arm_compute::block_by(expected_weight_format); ++ if (weights_md_.dims[inner_dim] % block_by != 0) ++ return status::unimplemented; + } + -+ weights_md_ = want_wei_md; ++ acl_utils::reorder_to_weight_format(aip.wei_tensor_info, ++ weights_md_, expected_weight_format, inner_dim, o_dim, ++ remaining_dims, {}); + // clang-format off ++ // Validate fully connected layer manually to check for return status ACL_CHECK_VALID(arm_compute::NEFullyConnectedLayer::validate( +- &aip.src_info, +- &aip.wei_info, +- aip.with_bias ? &aip.bia_info : nullptr, +- &aip.dst_info, +- aip.fc_info)); ++ &aip.src_tensor_info, ++ &aip.wei_tensor_info, ++ aip.with_bias ? 
&aip.bia_tensor_info : nullptr, ++ &aip.dst_tensor_info, ++ aip.fc_info, ++ aip.weights_info)); + // clang-format on ++ + return status::success; + } + }; // pd_t diff --git a/src/cpu/aarch64/acl_utils.cpp b/src/cpu/aarch64/acl_utils.cpp -index 79ea775d6..7ee4c7398 100644 +index 79ea775d6d..5792fd4911 100644 --- a/src/cpu/aarch64/acl_utils.cpp +++ b/src/cpu/aarch64/acl_utils.cpp -@@ -157,6 +157,28 @@ status_t tensor_info( - return status::success; +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. +@@ -261,6 +261,75 @@ int reorder_dimensions_by_stride(std::vector permuted_mds, + return reordered_dims; } -+status_t update_strides_y_and_z( -+ arm_compute::TensorInfo &info, const int y, const int z) { ++void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, ++ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, ++ std::vector spatial_dims, std::vector batch_dims) { ++ ++ md.format_kind = format_kind::blocked; ++ md.format_desc.blocking = blocking_desc_t {}; ++ const int interleaved_by = arm_compute::interleave_by(wf); ++ const int block_by = arm_compute::block_by(wf); + -+ arm_compute::TensorShape shape = info.tensor_shape(); -+ arm_compute::Strides old_strides_in_bytes = info.strides_in_bytes(); ++ // I dimension becomes densest (apart from blocking) ++ md.format_desc.blocking.strides[I_dim] = interleaved_by * block_by; ++ md.padded_dims[I_dim] = utils::rnd_up(md.dims[I_dim], block_by); + -+ arm_compute::Strides new_strides_in_bytes; -+ for(size_t i = 0; i < shape.num_dimensions(); ++i) { -+ new_strides_in_bytes.set(i, old_strides_in_bytes[i]); ++ // Then any spatial dimensions (e.g. 
HW) ++ dim_t ldb = interleaved_by * md.padded_dims[I_dim]; ++ for (dim_t sd : spatial_dims) { ++ md.format_desc.blocking.strides[sd] = ldb; ++ ldb *= md.padded_dims[sd]; + } + -+ // set y -+ new_strides_in_bytes.set(1, y); -+ // set z -+ new_strides_in_bytes.set(2, z); ++ // O dim (which was the innermost) becomes the outermost (apart from batching) ++ md.format_desc.blocking.strides[O_dim] = ldb; ++ md.padded_dims[O_dim] = utils::rnd_up(md.dims[O_dim], interleaved_by); + -+ info.init(info.tensor_shape(), info.num_channels(), info.data_type(), -+ new_strides_in_bytes, info.offset_first_element_in_bytes(), info.total_size()); ++ // Update the batch dimensions, starting with stride of the innermost batch ++ const dim_t innermost_batch_stride ++ = md.padded_dims[I_dim] * md.padded_dims[O_dim]; ++ dim_t batch_stride = innermost_batch_stride; ++ for (dim_t bd : batch_dims) { ++ md.format_desc.blocking.strides[bd] = batch_stride; ++ batch_stride *= md.padded_dims[bd]; ++ } ++ ++ // Weights can only be blocked if they are also interleaved ++ if (interleaved_by > 1) { ++ md.format_desc.blocking.inner_nblks = 1 + (block_by > 1); ++ ++ md.format_desc.blocking.inner_idxs[0] = O_dim; ++ md.format_desc.blocking.inner_blks[0] = interleaved_by; ++ if (block_by > 1) { ++ md.format_desc.blocking.inner_idxs[1] = I_dim; ++ md.format_desc.blocking.inner_blks[1] = block_by; ++ } ++ } ++ ++ if (arm_compute::is_fixed_format_fast_math(wf)) { ++ md.data_type = dnnl_bf16; ++ info.set_data_type(arm_compute::DataType::BFLOAT16); ++ } ++ ++ // The data layout is now determined by the manually set strides ++ info.set_data_layout(arm_compute::DataLayout::UNKNOWN); + -+ return status::success; ++ // x is ignored in fixed format kernels ++ // y is the leading dimension of b (ldb) in the GEMM d = a*b + c ++ // This is the stride of O_dim in the md ++ // z is the batch dimension (not strictly needed if there's only 1 batch) ++ // i.e. how much do I need to stride to get to the next matmul (ignoring ++ // the interleaving). Note that we use the innermost_batch_stride ++ // because all the batched dimensions are collapsed (as required by ACL). ++ arm_compute::Strides new_strides_in_bytes = info.strides_in_bytes(); ++ new_strides_in_bytes.set(1, ldb * info.element_size()); ++ new_strides_in_bytes.set(2, innermost_batch_stride * info.element_size()); ++ ++ info.init(info.tensor_shape(), info.num_channels(), info.data_type(), ++ new_strides_in_bytes, info.offset_first_element_in_bytes(), ++ memory_desc_wrapper(md).size()); +} + - status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i) { + } // namespace acl_utils - // Max 6 dims in ACL, so we can't insert another + } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_utils.hpp b/src/cpu/aarch64/acl_utils.hpp -index 28693bb16..c7c9e1278 100644 +index 28693bb167..d9affe1c8f 100644 --- a/src/cpu/aarch64/acl_utils.hpp +++ b/src/cpu/aarch64/acl_utils.hpp -@@ -62,6 +62,9 @@ status_t tensor_info(arm_compute::TensorInfo &info, const memory_desc_t &md); - status_t tensor_info( - arm_compute::TensorInfo &info, const memory_desc_wrapper &md); +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
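The stride and padding arithmetic in the reorder_to_weight_format body above is easier to follow with concrete numbers. The standalone sketch below (no ACL or oneDNN headers, so it compiles on its own) assumes a hypothetical 2D oc x ic = 10 x 6 weight and a fixed format kernel reporting interleave_by = 8 and block_by = 4; these values are illustrative only and are not taken from the patch.

#include <cstdint>
#include <iostream>

int64_t rnd_up(int64_t v, int64_t factor) { return (v + factor - 1) / factor * factor; }

int main() {
  const int64_t oc = 10, ic = 6;                        // hypothetical weight shape
  const int64_t interleave_by = 8, block_by = 4;        // assumed kernel blocking

  // The input-channel dimension becomes the densest one (apart from blocking).
  const int64_t padded_ic = rnd_up(ic, block_by);       // 8
  const int64_t stride_ic = interleave_by * block_by;   // 32

  // With no spatial dims, ldb is simply interleave_by * padded ic.
  const int64_t ldb = interleave_by * padded_ic;        // 64

  // The output-channel dimension becomes the outermost and is padded to the interleave.
  const int64_t stride_oc = ldb;                        // 64
  const int64_t padded_oc = rnd_up(oc, interleave_by);  // 16

  std::cout << "stride(ic)=" << stride_ic << " stride(oc)=" << stride_oc
            << " padded=" << padded_oc << "x" << padded_ic
            << " elements=" << padded_oc * padded_ic << "\n";  // 32, 64, 16x8, 128
  return 0;
}

Under these assumptions the reordered buffer holds 16 x 8 = 128 elements rather than the 10 x 6 = 60 logical ones, the difference being the padding the blocked layout requires.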
+@@ -74,6 +74,28 @@ status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); + int reorder_dimensions_by_stride(std::vector permuted_mds, + std::vector mds); -+// Update y and z strides in arm_compute::TensorInfo -+status_t update_strides_y_and_z(arm_compute::TensorInfo &info, const int y, const int z); ++// Reorder a memory_desc_t and set the strides on a arm_compute::TensorInfo to ++// match an arm_compute::WeightFormat. You are required to specify how various ++// logical dimensions in oneDNN correspond to logical dimensions in arm_compute. ++// info TensorInfo where the strides will be changed to match the reordering ++// md memory descriptor where the stride and padded dimensions will be ++// changed or reordering ++// wf Describes the memory format/layout of the weights ++// I_dim The logical dimension of md corresponding to the input channel of ++// a convolution or the K dimension in a matmul ++// O_dim The logical dimension of md corresponding to the output channel of a ++//   convolution or the N dimension in a matmul ++// spatial_dims The logical dimensions of md corresponding to the spatial ++// dimensions of the weights (H, W, D for example). These will be ++// the next densest after the inner blocks and the input channel. ++// batch_dims The logical dimensions of md related to the batch in a batched ++// matmul, ordered from innermost to outermost. ACL calls these ++// the multi_stride_b. These will become the outermost (least dense) ++// dimensions and will be collapsed. ++void reorder_to_weight_format(arm_compute::TensorInfo &info, memory_desc_t &md, ++ arm_compute::WeightFormat wf, dim_t I_dim, dim_t O_dim, ++ std::vector spatial_dims, std::vector batch_dims = {}); + - // Insert a dimension of size 1 at the index dim_i of TensorInfo - status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); + // Logs a custom 'info' line describing an unsupported case + #define LOG_ACL_UNSUPPORTED(msg) \ + do { \ +diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp +index dce220fb6e..ca1c7eb47e 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.cpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.cpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
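To make the reorder_to_weight_format parameter documentation above more tangible, the sketch below works out where a logical weight element (o, i) lands in the reordered buffer, continuing the illustrative 2D example from earlier (interleave_by = 8, block_by = 4, ic padded to 8). The formula mirrors the strides and inner blocks set in acl_utils.cpp, but the numbers are assumptions, not values produced by the patch.

#include <cstdint>
#include <iostream>

// Offset in elements of logical element (o, i) under the assumed blocking.
int64_t blocked_offset(int64_t o, int64_t i) {
  const int64_t interleave_by = 8, block_by = 4;
  const int64_t padded_ic = 8;                           // rnd_up(6, block_by)
  const int64_t stride_ic = interleave_by * block_by;    // 32
  const int64_t stride_oc = interleave_by * padded_ic;   // 64
  return (o / interleave_by) * stride_oc + (i / block_by) * stride_ic
          + (o % interleave_by) * block_by + (i % block_by);
}

int main() {
  // One 8 x 4 tile is contiguous: (0,0)..(0,3), then (1,0)..(1,3), and so on.
  std::cout << blocked_offset(0, 0) << " " << blocked_offset(0, 3) << " "
            << blocked_offset(1, 0) << " " << blocked_offset(8, 0) << "\n";  // 0 3 4 64
  return 0;
}

In other words, the O dimension is interleaved in groups of 8 and the I dimension is blocked in runs of 4, which is the sort of tile a fixed format kernel can consume directly.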
+@@ -31,36 +31,19 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); + + bool is_transA = pd()->amp_.is_transA; +- bool is_transB = pd()->amp_.is_transB; + bool use_dst_acc = pd()->amp_.use_dst_acc; + + std::lock_guard _lock {this->mtx}; + auto *acl_resource = ctx.get_resource_mapper()->get(this); + acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj(); + // Run transpose kernel +- if (is_transA && !is_transB) { ++ if (is_transA) { + acl_obj.src_tensor.allocator()->allocate(); + acl_obj.src_acc_tensor.allocator()->import_memory( + const_cast(src_base)); + acl_obj.transA.run(); + acl_obj.wei_tensor.allocator()->import_memory( + const_cast(wei_base)); +- } else if (is_transB && !is_transA) { +- acl_obj.wei_tensor.allocator()->allocate(); +- acl_obj.wei_acc_tensor.allocator()->import_memory( +- const_cast(wei_base)); +- acl_obj.transB.run(); +- acl_obj.src_tensor.allocator()->import_memory( +- const_cast(src_base)); +- } else if (is_transA && is_transB) { +- acl_obj.src_tensor.allocator()->allocate(); +- acl_obj.src_acc_tensor.allocator()->import_memory( +- const_cast(src_base)); +- acl_obj.wei_tensor.allocator()->allocate(); +- acl_obj.wei_acc_tensor.allocator()->import_memory( +- const_cast(wei_base)); +- acl_obj.transA.run(); +- acl_obj.transB.run(); + } else { + acl_obj.src_tensor.allocator()->import_memory( + const_cast(src_base)); +@@ -69,7 +52,7 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + } + + if (use_dst_acc) { +- // Put the result in a new tensor, it will be accumalated to the dst ++ // Put the result in a new tensor, it will be accumulated to the dst + // during the post ops + acl_obj.dst_tensor.allocator()->allocate(); + } else { +@@ -82,7 +65,6 @@ status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const { + acl_obj.src_tensor.allocator()->free(); + acl_obj.wei_tensor.allocator()->free(); + if (is_transA) acl_obj.src_acc_tensor.allocator()->free(); +- if (is_transB) acl_obj.wei_acc_tensor.allocator()->free(); + void *dst = acl_obj.dst_tensor.buffer(); + pd()->post_ops.execute(ctx, dst); +diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp +index cdc942e995..832b1dbb68 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp +@@ -32,20 +32,15 @@ struct acl_resource_t : public resource_t { + + status_t configure(const acl_matmul_conf_t &) { + if (!acl_obj_) return status::out_of_memory; +- acl_obj_->src_tensor.allocator()->init(amp.src_info); +- acl_obj_->wei_tensor.allocator()->init(amp.wei_info); +- acl_obj_->dst_tensor.allocator()->init(amp.dst_info); ++ acl_obj_->src_tensor.allocator()->init(amp.src_tensor_info); ++ acl_obj_->wei_tensor.allocator()->init(amp.wei_tensor_info); ++ acl_obj_->dst_tensor.allocator()->init(amp.dst_tensor_info); + // Configure transpose kernel for src, wei or both + if (amp.is_transA) { + acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info); + acl_obj_->transA.configure( + &acl_obj_->src_acc_tensor, &acl_obj_->src_tensor); + } +- if (amp.is_transB) { +- acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info); +- acl_obj_->transB.configure( +- &acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor); +- } + // Configure GEMM + acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor, + nullptr, &acl_obj_->dst_tensor, amp.alpha, 0.0f, amp.gemm_info); +@@ -72,12 +67,20 @@ struct acl_matmul_t : public primitive_t { + + status_t 
init(engine_t *engine) { + using smask_t = primitive_attr_t::skip_mask_t; +- bool ok = src_md()->data_type == data_type::f32 +- && weights_md()->data_type == data_type::f32 +- && desc()->accum_data_type == data_type::f32 +- && dst_md()->data_type == data_type::f32 +- && platform::has_data_type_support(data_type::f32) ++ const bool is_fp32_ok ++ = utils::everyone_is(data_type::f32, src_md()->data_type, ++ weights_md()->data_type, dst_md()->data_type, ++ desc()->accum_data_type) ++ && platform::has_data_type_support(data_type::f32); ++ const bool is_fp16_ok ++ = utils::everyone_is(data_type::f16, src_md()->data_type, ++ weights_md()->data_type, dst_md()->data_type) ++ && platform::has_data_type_support(data_type::f16); ++ bool ok = is_dense_data() ++ && utils::one_of(true, is_fp32_ok, is_fp16_ok) + && !has_zero_dim_memory() ++ && weights_md_.format_kind == format_kind::any ++ && set_default_formats() + && attr()->has_default_values( + smask_t::oscale | smask_t::post_ops) + && attr_oscale_ok() && !has_runtime_dims_or_strides(); +@@ -92,9 +95,9 @@ struct acl_matmul_t : public primitive_t { + amp_.use_dst_acc = post_ops.has_sum(); + + // Validate ACL GEMM +- ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_info, +- &_.wei_info, nullptr, &_.dst_info, amp_.alpha, 0.0f, +- amp_.gemm_info)); ++ ACL_CHECK_VALID(arm_compute::NEGEMM::validate(&_.src_tensor_info, ++ &_.wei_tensor_info, nullptr, &_.dst_tensor_info, ++ amp_.alpha, 0.0f, amp_.gemm_info)); + + return status::success; + } diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -index 679baec3a..853277e37 100644 +index 679baec3a4..30bc2c1443 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp -@@ -66,15 +66,12 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, +@@ -41,6 +41,7 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + const dim_t src_batch = helper.src_batch(); + const dim_t wei_batch = helper.wei_batch(); + ++ // We can only broadcast on one of src or wei at once + // ACL supports broadcast for 3D shapes, and 4D shapes + // for e.g when ab in abcd is 1x1 + bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) +@@ -53,44 +54,33 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + bool with_bias = md.bias_desc.format_kind != format_kind::undef; + ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); - // Transpose A (src) or B (wei) ++ // The two innermost dimensions can be transposed, but the batch dimensions ++ // must be the outermost + using namespace format_tag; + auto src_tag = memory_desc_matches_one_of_tag( + src_md, abcd, abdc, abc, acb, ab, ba); +- auto wei_tag = memory_desc_matches_one_of_tag( +- wei_md, abcd, abdc, abc, acb, ab, ba); +- auto dst_tag +- = memory_desc_matches_one_of_tag(dst_md, abcd, abc, acb, ab, ba); +- ACL_CHECK_SUPPORT( +- utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag), ++ auto dst_tag = memory_desc_matches_one_of_tag(dst_md, abcd, abc, ab, ba); ++ ACL_CHECK_SUPPORT(utils::one_of(format_tag::undef, src_tag, dst_tag), + "Format tag is undefined"); + +- // Transpose A (src) or B (wei) ++ // Transpose A (src) amp.is_transA = helper.transA() == 'T'; - amp.is_transB = helper.transB() == 'T'; -+ amp.is_transB = false; ++ ++ auto acl_src_data_t = acl_utils::get_acl_data_t(src_md.data_type); ++ auto acl_wei_data_t = acl_utils::get_acl_data_t(wei_md.data_type); ++ auto acl_dst_data_t = 
acl_utils::get_acl_data_t(dst_md.data_type); + if (amp.is_transA) amp.src_acc_info = arm_compute::TensorInfo( arm_compute::TensorShape(M, K, 1, src_batch), 1, - arm_compute::DataType::F32); +- arm_compute::DataType::F32); - if (amp.is_transB) - amp.wei_acc_info = arm_compute::TensorInfo( - arm_compute::TensorShape(K, N, wei_batch), 1, - arm_compute::DataType::F32); +- +- amp.src_info = arm_compute::TensorInfo( +- arm_compute::TensorShape(K, M, 1, src_batch), 1, +- arm_compute::DataType::F32); +- amp.wei_info +- = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, wei_batch), +- 1, arm_compute::DataType::F32); +- amp.dst_info = arm_compute::TensorInfo( +- arm_compute::TensorShape(N, M, 1, dst_batch), 1, +- arm_compute::DataType::F32); +- +- // Fast-math mode +- auto math_mode = get_fpmath_mode(); +- bool is_fastmath_enabled +- = utils::one_of(math_mode, fpmath_mode::bf16, fpmath_mode::any); +- amp.gemm_info.set_fast_math(is_fastmath_enabled); ++ acl_src_data_t); ++ ++ amp.src_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(K, M, 1, src_batch), 1, acl_src_data_t); ++ amp.wei_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(N, K, wei_batch), 1, acl_wei_data_t); ++ amp.dst_tensor_info = arm_compute::TensorInfo( ++ arm_compute::TensorShape(N, M, 1, dst_batch), 1, acl_dst_data_t); - amp.src_info = arm_compute::TensorInfo( - arm_compute::TensorShape(K, M, 1, src_batch), 1, -@@ -103,6 +100,140 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + // Set alpha (output scaling) + amp.alpha = attr.output_scales_.scales_[0]; +@@ -98,10 +88,45 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, + // Validate ACL transpose + if (amp.is_transA) ACL_CHECK_VALID(arm_compute::NETranspose::validate( - &.wei_acc_info, &.wei_info)); - -+ arm_compute::WeightFormat expected_weight_format; +- &.src_acc_info, &.src_info)); +- if (amp.is_transB) +- ACL_CHECK_VALID(arm_compute::NETranspose::validate( +- &.wei_acc_info, &.wei_info)); ++ &.src_acc_info, &.src_tensor_info)); ++ ++ bool is_fastmath_enabled = utils::one_of( ++ attr.fpmath_mode_, fpmath_mode::bf16, fpmath_mode::any); ++ amp.gemm_info.set_fast_math(is_fastmath_enabled); + + amp.gemm_info.set_fixed_format(true); ++ ++ // WeightFormat::ANY tells ACL we can handle any format + amp.gemm_info.set_weight_format(arm_compute::WeightFormat::ANY); + -+ auto acl_st = arm_compute::NEGEMM::has_opt_impl( -+ expected_weight_format, -+ &.src_info, -+ &.wei_info, -+ nullptr, -+ &.dst_info, -+ amp.alpha, -+ 0.0f, -+ amp.gemm_info); -+ -+ if(acl_st.error_code() != arm_compute::ErrorCode::OK) { -+ return status::unimplemented; -+ } ++ // Get the format that the ACL kernel will expect the weights to be ++ // in (if a kernel exists). 
Note that these are referred to as fixed format ++ // kernels, because they require one specific weights format ++ arm_compute::WeightFormat expected_weight_format; ++ ACL_CHECK_VALID(arm_compute::NEGEMM::has_opt_impl(expected_weight_format, ++ &.src_tensor_info, &.wei_tensor_info, nullptr, ++ &.dst_tensor_info, amp.alpha, 0.0f, amp.gemm_info)); + ++ // Set gemm weights info to the one returned by has_opt_impl + amp.gemm_info.set_weight_format(expected_weight_format); + -+ memory_desc_t want_wei_md = wei_md; -+ -+ // We need to transpose second to last dimension and use blocking -+ // as returned by interleave by from expecting strides -+ int interleaved_by = arm_compute::interleave_by(expected_weight_format); -+ int block_by = arm_compute::block_by(expected_weight_format); -+ bool is_fast_math_kernel = arm_compute::is_fixed_format_fast_math(expected_weight_format); -+ if(!is_fast_math_kernel) { -+ amp.gemm_info.set_fast_math(false); -+ } -+ -+ int blocked_first_dimension = -1; -+ int blocked_second_dimension = -1; -+ -+ // Assume that interleaved by is X and blocked by is Y -+ switch(want_wei_md.ndims) { -+ case 2: { -+ // For 2D case the format that we need to pass is BaXb and -+ // when doing fast mode BAXbYa -+ want_wei_md.format_desc.blocking.strides[0] = interleaved_by * block_by; -+ // check to see whether we need to pad -+ if(want_wei_md.dims[0] % block_by != 0) { -+ want_wei_md.padded_dims[0] = utils::div_up(want_wei_md.dims[0], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by * want_wei_md.padded_dims[0]; -+ if(want_wei_md.dims[1] % interleaved_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], interleaved_by) * interleaved_by; -+ } -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); -+ -+ blocked_first_dimension = 1; -+ blocked_second_dimension = 0; -+ -+ break; -+ } -+ -+ case 3: { -+ // For 3D case the format we need to pass is aCbXc and -+ // when doing fast mode is aCBXcYb -+ want_wei_md.format_desc.blocking.strides[1] = interleaved_by*block_by; -+ if(want_wei_md.dims[1] % block_by != 0) { -+ want_wei_md.padded_dims[1] = utils::div_up(want_wei_md.dims[1], block_by) * block_by; -+ } -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by * want_wei_md.padded_dims[1]; -+ if(want_wei_md.dims[2] % interleaved_by != 0) { -+ want_wei_md.padded_dims[2] = utils::div_up(want_wei_md.dims[2], interleaved_by) * interleaved_by; -+ } -+ want_wei_md.format_desc.blocking.strides[0] = want_wei_md.padded_dims[2] * want_wei_md.padded_dims[1]; -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[2] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[0] * wei_d.data_type_size()); -+ -+ blocked_first_dimension = 2; -+ blocked_second_dimension = 1; -+ -+ break; -+ } -+ -+ case 4: { -+ // For 4D case the format we need to pass is abDcXd and -+ // when doing fast mode is abDCxdYc -+ int D_padded = want_wei_md.dims[3]; -+ if(D_padded % interleaved_by != 0) { -+ D_padded = utils::div_up(D_padded, interleaved_by) * interleaved_by; -+ want_wei_md.padded_dims[3] = D_padded; -+ } -+ -+ int C_padded = want_wei_md.dims[2]; -+ if(C_padded % block_by != 0) { -+ C_padded = utils::div_up(C_padded, block_by) * block_by; -+ want_wei_md.padded_dims[2] = C_padded; -+ } -+ -+ want_wei_md.format_desc.blocking.strides[0] = 
want_wei_md.dims[1]*D_padded*C_padded; -+ want_wei_md.format_desc.blocking.strides[1] = D_padded*C_padded; -+ want_wei_md.format_desc.blocking.strides[2] = interleaved_by*block_by; -+ want_wei_md.format_desc.blocking.strides[3] = interleaved_by*C_padded; -+ -+ acl_utils::update_strides_y_and_z( -+ amp.wei_info, -+ want_wei_md.format_desc.blocking.strides[3] * wei_d.data_type_size(), -+ want_wei_md.format_desc.blocking.strides[1] * wei_d.data_type_size()); ++ // has_opt_impl may return a non fast math kernel, even if we requested one ++ amp.gemm_info.set_fast_math( ++ arm_compute::is_fixed_format_fast_math(expected_weight_format)); + -+ blocked_first_dimension = 3; -+ blocked_second_dimension = 2; ++ // Logical dimension indices ++ dim_t innermost_dim = wei_md.ndims - 1; ++ dim_t N_dim = innermost_dim; ++ dim_t K_dim = innermost_dim - 1; + -+ break; -+ } -+ -+ default: -+ return status::unimplemented; -+ } -+ -+ want_wei_md.format_desc.blocking.inner_nblks = (block_by > 1) + 1; -+ want_wei_md.format_desc.blocking.inner_idxs[0] = blocked_first_dimension; -+ want_wei_md.format_desc.blocking.inner_blks[0] = interleaved_by; -+ if(block_by > 1) { -+ want_wei_md.format_desc.blocking.inner_idxs[1] = blocked_second_dimension; -+ want_wei_md.format_desc.blocking.inner_blks[1] = block_by; -+ } -+ -+ if(is_fast_math_kernel) { -+ want_wei_md.data_type = dnnl_bf16; -+ } -+ -+ wei_md = want_wei_md; ++ // The logical indices of dimensions related to the batch, ordered from ++ // innermost to outermost ++ std::vector batch_dims = {}; ++ for (dim_t i = K_dim - 1; i >= 0; --i) ++ batch_dims.push_back(i); + ++ acl_utils::reorder_to_weight_format(amp.wei_tensor_info, wei_md, ++ expected_weight_format, K_dim, N_dim, {}, batch_dims); + return status::success; } +diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +index 0a5ee6a987..67bb2e78eb 100644 +--- a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp ++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp +@@ -1,5 +1,5 @@ + /******************************************************************************* +-* Copyright 2021-2022 Arm Ltd. and affiliates ++* Copyright 2021-2023 Arm Ltd. and affiliates + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
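The batched matmul path above derives the dimension roles from wei_md.ndims and then lets reorder_to_weight_format collapse the batch. The standalone sketch below repeats that derivation for a hypothetical [2, 3, K=4, N=5] weights tensor and, for simplicity, assumes a kernel with interleave_by = block_by = 1 so the padded dims equal the plain dims; real fixed format kernels will interleave, so treat the numbers as illustrative only.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {2, 3, 4, 5};   // [batch0, batch1, K, N]
  const int64_t ndims = static_cast<int64_t>(dims.size());

  // Same role assignment as init_conf_matmul above.
  const int64_t N_dim = ndims - 1;                  // 3
  const int64_t K_dim = ndims - 2;                  // 2
  std::vector<int64_t> batch_dims;                  // innermost to outermost: 1, 0
  for (int64_t d = K_dim - 1; d >= 0; --d) batch_dims.push_back(d);

  // Strides reorder_to_weight_format would assign under the 1/1 assumption.
  std::vector<int64_t> strides(ndims, 0);
  strides[K_dim] = 1;                               // K becomes the densest dim
  strides[N_dim] = dims[K_dim];                     // ldb = 4
  int64_t batch_stride = dims[K_dim] * dims[N_dim]; // 20, one collapsed K x N matrix
  for (int64_t bd : batch_dims) {
    strides[bd] = batch_stride;
    batch_stride *= dims[bd];
  }

  std::cout << "N_dim=" << N_dim << " K_dim=" << K_dim << " strides:";
  for (int64_t s : strides) std::cout << " " << s;  // 60 20 1 4
  std::cout << "\n";
  return 0;
}

Each batch entry is thus one contiguous K x N block of 20 elements, which is what lets the batch dimensions be treated as the single collapsed multi_stride_b that ACL expects.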
+@@ -29,25 +29,21 @@ namespace aarch64 { + struct acl_matmul_obj_t { + arm_compute::NEGEMM gemm; + arm_compute::NETranspose transA; +- arm_compute::NETranspose transB; + arm_compute::Tensor src_tensor; + arm_compute::Tensor src_acc_tensor; + arm_compute::Tensor wei_tensor; +- arm_compute::Tensor wei_acc_tensor; + arm_compute::Tensor dst_tensor; + }; + struct acl_matmul_conf_t { + bool is_transA; +- bool is_transB; + // If this is true, the result of the matmul goes into a temporarily + // allocated ACL tensor to be accumulated into the oneDNN dst during postops + bool use_dst_acc; +- arm_compute::TensorInfo src_info; ++ arm_compute::TensorInfo src_tensor_info; + arm_compute::TensorInfo src_acc_info; +- arm_compute::TensorInfo wei_info; +- arm_compute::TensorInfo wei_acc_info; +- arm_compute::TensorInfo dst_info; ++ arm_compute::TensorInfo wei_tensor_info; ++ arm_compute::TensorInfo dst_tensor_info; + arm_compute::GEMMInfo gemm_info; + float alpha; + }; diff --git a/third_party/mkl_dnn/onednn_acl_remove_winograd.patch b/third_party/mkl_dnn/onednn_acl_remove_winograd.patch new file mode 100644 index 00000000000000..18abcc8f54e922 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_remove_winograd.patch @@ -0,0 +1,326 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ******************************************************************************* +diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp +index c46d697575..37f8ecbc06 100644 +--- a/src/cpu/aarch64/acl_convolution_utils.cpp ++++ b/src/cpu/aarch64/acl_convolution_utils.cpp +@@ -271,54 +271,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + return status::success; + } + +-status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, +- memory_desc_t &weights_md, memory_desc_t &dst_md, +- memory_desc_t &bias_md, const convolution_desc_t &cd, +- const primitive_attr_t &attr) { +- +- // Under these conditions, fallback to faster GEMM-based convolution +- // unless the user explicitly specifies Winograd algorithm +- // clang-format off +- if (one_of(true, src_md.dims[2] > 112, // ih +- src_md.dims[3] > 112, // iw +- src_md.dims[1] < 64, // ic +- dst_md.dims[1] < 64, // oc +- dnnl_get_max_threads() > 28) +- && cd.alg_kind == alg_kind::convolution_auto) { +- return status::unimplemented; +- } +- // clang-format on +- +- // General Compute Library checks, memory tags are also set there +- CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); +- +- const bool shape_ok +- // only unit strides allowed +- = (acp.padstride_info.stride() == std::pair {1, 1}) +- // Note: Compute Library supports arbitrary padding for wino kernels +- // but we only allow small padding to be consistent with oneDNN +- && (acp.padstride_info.pad().first <= 1) // padding left/right +- && (acp.padstride_info.pad().second <= 1) // padding top/bottom +- // only non-dilated convolutions allowed +- && (acp.dilation_info == arm_compute::Size2D(1, 1)); +- +- ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); +- +- // clang-format off +- // Validate convolution manually to check for return status +- ACL_CHECK_VALID(arm_compute::NEWinogradConvolutionLayer::validate( +- &acp.src_info, +- &acp.wei_info, +- acp.with_bias ? 
&acp.bia_info : nullptr, +- &acp.dst_info, +- acp.padstride_info, +- acp.act_info, +- true)); // enable_fast_math flag in ACL Winograd +- // clang-format on +- +- return status::success; +-} +- + } // namespace acl_convolution_utils + + } // namespace aarch64 +diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp +index 3e56245faf..0398ab06b9 100644 +--- a/src/cpu/aarch64/acl_convolution_utils.hpp ++++ b/src/cpu/aarch64/acl_convolution_utils.hpp +@@ -66,11 +66,6 @@ status_t init_conf_indirect_gemm(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + +-status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, +- memory_desc_t &weights_md, memory_desc_t &dst_md, +- memory_desc_t &bias_md, const convolution_desc_t &cd, +- const primitive_attr_t &attr); +- + } // namespace acl_convolution_utils + + template _lock {this->mtx}; +- // Retrieve primitive resource and configured Compute Library objects +- auto *acl_resource +- = ctx.get_resource_mapper()->get(this); +- acl_obj_t &acl_wino_obj +- = acl_resource->get_acl_obj(); +- +- return execute_forward_conv_acl< +- acl_obj_t, pd_t, data_t>( +- ctx, acl_wino_obj, pd()); +-} +- +-} // namespace aarch64 +-} // namespace cpu +-} // namespace impl +-} // namespace dnnl +diff --git a/src/cpu/aarch64/acl_winograd_convolution.hpp b/src/cpu/aarch64/acl_winograd_convolution.hpp +deleted file mode 100644 +index 215635fe3f..0000000000 +--- a/src/cpu/aarch64/acl_winograd_convolution.hpp ++++ /dev/null +@@ -1,146 +0,0 @@ +-/******************************************************************************* +-* Copyright 2020-2022 Arm Ltd. and affiliates +-* +-* Licensed under the Apache License, Version 2.0 (the "License"); +-* you may not use this file except in compliance with the License. +-* You may obtain a copy of the License at +-* +-* http://www.apache.org/licenses/LICENSE-2.0 +-* +-* Unless required by applicable law or agreed to in writing, software +-* distributed under the License is distributed on an "AS IS" BASIS, +-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-* See the License for the specific language governing permissions and +-* limitations under the License. +-*******************************************************************************/ +- +-#ifndef CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +-#define CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +- +-#include "cpu/cpu_convolution_pd.hpp" +- +-#include "cpu/aarch64/acl_convolution_utils.hpp" +- +-namespace dnnl { +-namespace impl { +-namespace cpu { +-namespace aarch64 { +- +-struct acl_wino_resource_t : public resource_t { +- acl_wino_resource_t() +- : acl_wino_obj_(utils::make_unique< +- acl_obj_t>()) {} +- +- status_t configure(const acl_conv_conf_t &acp) { +- if (!acl_wino_obj_) return status::out_of_memory; +- +- // Init Compute Library tensors based on info from descriptor +- acl_wino_obj_->src_tensor.allocator()->init(acp.src_info); +- acl_wino_obj_->wei_tensor.allocator()->init(acp.wei_info); +- acl_wino_obj_->dst_tensor.allocator()->init(acp.dst_info); +- acl_wino_obj_->bia_tensor.allocator()->init(acp.bia_info); +- +- // clang-format off +- acl_wino_obj_->conv.configure( +- &acl_wino_obj_->src_tensor, +- &acl_wino_obj_->wei_tensor, +- acp.with_bias ? 
&acl_wino_obj_->bia_tensor : nullptr, +- &acl_wino_obj_->dst_tensor, +- acp.padstride_info, +- acp.act_info, +- true); // to support 5x5, 7x7 filter shapes in addition to 3x3 +- // clang-format on +- +- return status::success; +- } +- +- acl_obj_t &get_acl_obj() const { +- return *acl_wino_obj_; +- } +- +- DNNL_DISALLOW_COPY_AND_ASSIGN(acl_wino_resource_t); +- +-private: +- std::unique_ptr> +- acl_wino_obj_; +-}; // acl_wino_resource_t +- +-struct acl_wino_convolution_fwd_t : public primitive_t { +- struct pd_t : public cpu_convolution_fwd_pd_t { +- pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, +- const typename pd_t::base_class *hint_fwd_pd) +- : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) +- , acp_() +- , post_ops() {} +- +- DECLARE_COMMON_PD_T( +- "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); +- +- status_t init(engine_t *engine) { +- bool ok = is_fwd() +- && utils::one_of(desc()->alg_kind, +- alg_kind::convolution_auto, +- alg_kind::convolution_winograd) +- && expect_data_types(data_type::f32, data_type::f32, +- data_type::f32, data_type::f32, data_type::f32) +- && attr()->has_default_values( +- primitive_attr_t::skip_mask_t::post_ops, +- data_type::f32) +- && !has_zero_dim_memory(); +- if (!ok) return status::unimplemented; +- +- CHECK(acl_convolution_utils::init_conf_wino(acp_, src_md_, +- weights_md_, dst_md_, bias_md_, *desc(), *attr())); +- +- set_default_alg_kind(alg_kind::convolution_winograd); +- +- CHECK(post_ops.init( +- engine, attr_.post_ops_, dst_md_, acp_.act_info)); +- acp_.use_dst_acc = post_ops.has_sum(); +- +- return status::success; +- } +- +- acl_conv_conf_t acp_; +- acl_post_ops_t post_ops; +- }; +- +- acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} +- +- status_t create_resource( +- engine_t *engine, resource_mapper_t &mapper) const override { +- if (mapper.has_resource(this)) return status::success; +- +- auto r = utils::make_unique(); +- if (!r) return status::out_of_memory; +- +- // Configure the resource based on information from primitive descriptor +- CHECK(r->configure(pd()->acp_)); +- mapper.add(this, std::move(r)); +- +- CHECK(pd()->post_ops.create_resource(engine, mapper)); +- +- return status::success; +- } +- +- ~acl_wino_convolution_fwd_t() {} +- +- typedef typename prec_traits::type data_t; +- +- status_t execute(const exec_ctx_t &ctx) const override { +- return execute_forward(ctx); +- } +- +-private: +- // To guard the const execute_forward(), the mutex must be 'mutable' +- mutable std::mutex mtx; +- status_t execute_forward(const exec_ctx_t &ctx) const; +- const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } +-}; // acl_wino_convolution_fwd_t +- +-} // namespace aarch64 +-} // namespace cpu +-} // namespace impl +-} // namespace dnnl +- +-#endif // CPU_AARCH64_ACL_WINOGRAD_CONVOLUTION_HPP +diff --git a/src/cpu/cpu_convolution_list.cpp b/src/cpu/cpu_convolution_list.cpp +index 4142dbc7e7..094c73aa36 100644 +--- a/src/cpu/cpu_convolution_list.cpp ++++ b/src/cpu/cpu_convolution_list.cpp +@@ -65,7 +65,6 @@ using namespace dnnl::impl::cpu::x64; + #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL + #include "cpu/aarch64/acl_gemm_convolution.hpp" + #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" +-#include "cpu/aarch64/acl_winograd_convolution.hpp" + #endif + using namespace dnnl::impl::cpu::aarch64; + #endif +@@ -100,7 +99,6 @@ const std::map> &impl_list_map() + CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t) + CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t) + 
CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t) +- CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t) + CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t) + CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t) +diff --git a/tests/gtests/test_iface_wino_convolution.cpp b/tests/gtests/test_iface_wino_convolution.cpp +index 03861b1de4..2235ceae36 100644 +--- a/tests/gtests/test_iface_wino_convolution.cpp ++++ b/tests/gtests/test_iface_wino_convolution.cpp +@@ -59,9 +59,6 @@ protected: + input_f16.wino_supported = is_gpu; + input_int8.wino_supported = is_cpu && has_avx512_core; + input_f32.backward_supported = is_cpu && impl::dnnl_thr_syncable(); +-#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL +- const bool is_cpu = get_test_engine_kind() == engine::kind::cpu; +- input_f32.wino_supported = is_cpu; + #endif + + #else diff --git a/third_party/mkl_dnn/onednn_acl_reorder.patch b/third_party/mkl_dnn/onednn_acl_reorder.patch new file mode 100644 index 00000000000000..7241aca4eefc88 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_reorder.patch @@ -0,0 +1,370 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp +new file mode 100644 +index 0000000000..061751b555 +--- /dev/null ++++ b/src/cpu/aarch64/acl_reorder.cpp +@@ -0,0 +1,52 @@ ++/******************************************************************************* ++* Copyright 2023 Arm Ltd. and affiliates ++* ++* Licensed under the Apache License, Version 2.0 (the "License"); ++* you may not use this file except in compliance with the License. ++* You may obtain a copy of the License at ++* ++* http://www.apache.org/licenses/LICENSE-2.0 ++* ++* Unless required by applicable law or agreed to in writing, software ++* distributed under the License is distributed on an "AS IS" BASIS, ++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++* See the License for the specific language governing permissions and ++* limitations under the License. ++*******************************************************************************/ ++ ++#include "cpu/aarch64/acl_reorder.hpp" ++ ++namespace dnnl { ++namespace impl { ++namespace cpu { ++namespace aarch64 { ++ ++status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { ++ // Lock here is needed because resource_mapper does not support ++ // concurrent multithreaded access. 
++ std::lock_guard _lock {this->mtx}; ++ ++ auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM); ++ auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO); ++ ++ // Retrieve primitive resource and configured Compute Library objects ++ auto *acl_resource ++ = ctx.get_resource_mapper()->get(this); ++ ++ acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj(); ++ ++ acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); ++ acl_obj.dst_tensor.allocator()->import_memory(dst); ++ ++ acl_obj.reorder.run(); ++ ++ acl_obj.src_tensor.allocator()->free(); ++ acl_obj.dst_tensor.allocator()->free(); ++ ++ return status::success; ++} ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl +diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp +new file mode 100644 +index 0000000000..edbc38914d +--- /dev/null ++++ b/src/cpu/aarch64/acl_reorder.hpp +@@ -0,0 +1,262 @@ ++/******************************************************************************* ++* Copyright 2023 Arm Ltd. and affiliates ++* ++* Licensed under the Apache License, Version 2.0 (the "License"); ++* you may not use this file except in compliance with the License. ++* You may obtain a copy of the License at ++* ++* http://www.apache.org/licenses/LICENSE-2.0 ++* ++* Unless required by applicable law or agreed to in writing, software ++* distributed under the License is distributed on an "AS IS" BASIS, ++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++* See the License for the specific language governing permissions and ++* limitations under the License. ++*******************************************************************************/ ++#ifndef CPU_AARCH64_ACL_REORDER_HPP ++#define CPU_AARCH64_ACL_REORDER_HPP ++ ++#include "cpu/aarch64/acl_utils.hpp" ++#include "cpu/reorder/cpu_reorder_pd.hpp" ++#include "arm_compute/core/Types.h" ++#include "common/utils.hpp" ++ ++namespace dnnl { ++namespace impl { ++namespace cpu { ++namespace aarch64 { ++ ++struct acl_reorder_obj_t { ++ arm_compute::NEReorderLayer reorder; ++ arm_compute::Tensor src_tensor; ++ arm_compute::Tensor dst_tensor; ++ arm_compute::WeightFormat src_wf; ++ arm_compute::WeightFormat dst_wf; ++}; ++ ++struct acl_reorder_conf_t { ++ arm_compute::TensorInfo src_info; ++ arm_compute::TensorInfo dst_info; ++ arm_compute::WeightFormat src_wf; ++ arm_compute::WeightFormat dst_wf; ++}; ++ ++struct acl_reorder_resource_t : public resource_t { ++ acl_reorder_resource_t() : acl_obj_(utils::make_unique()) {} ++ ++ status_t configure(const acl_reorder_conf_t &app) { ++ if (!acl_obj_) return status::out_of_memory; ++ ++ // Init Compute Library tensors based on info from descriptor ++ acl_obj_->src_tensor.allocator()->init(app.src_info); ++ acl_obj_->dst_tensor.allocator()->init(app.dst_info); ++ ++ // clang-format off ++ acl_obj_->reorder.configure( ++ &acl_obj_->src_tensor, ++ &acl_obj_->dst_tensor, ++ app.src_wf, ++ app.dst_wf ++ ); ++ // clang-format on ++ ++ return status::success; ++ } ++ ++ acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; } ++ DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t); ++ ++private: ++ std::unique_ptr acl_obj_; ++}; // acl_reorder_resource_t ++ ++struct acl_reorder_fwd_t : public primitive_t { ++ using primitive_t::primitive_t; ++ struct pd_t : public cpu_reorder_pd_t { ++ ++ using cpu_reorder_pd_t::cpu_reorder_pd_t; ++ ++ DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t); ++ ++ static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, ++ const 
primitive_attr_t *attr, engine_t *src_engine, ++ const memory_desc_t *src_md, engine_t *dst_engine, ++ const memory_desc_t *dst_md) { ++ ++ using namespace acl_utils; ++ // using skip_mask_t = dnnl_primitive_attr::skip_mask_t; ++ ++ bool ok = src_md->data_type ++ == dst_md->data_type // ACL only supports matching src/dst data types ++ && utils::one_of(src_md->data_type, ++ data_type::f32) // Only supports f32 for now ++ && attr->has_default_values(); ++ if (!ok) return status::unimplemented; ++ ++ int mask = -1; ++ bool is_set = false; ++ // CHECK(attr->scales_.get(DNNL_ARG_DST, &mask, &is_set)); ++ const memory_desc_wrapper input_d(src_md); ++ if (input_d.has_runtime_dims_or_strides() && is_set && mask > 0) ++ return status::unimplemented; ++ ++ // Create and check primitive descriptor ++ auto _pd = new pd_t(attr, src_engine->kind(), src_md, ++ dst_engine->kind(), dst_md); ++ if (_pd == nullptr) return status::out_of_memory; ++ if (_pd->init(engine, src_engine, dst_engine) != status::success) { ++ delete _pd; ++ return status::unimplemented; ++ } ++ ++ const memory_desc_wrapper src_d(*src_md); ++ const memory_desc_wrapper dst_d(*dst_md); ++ ++ const int ndims = src_d.ndims(); ++ ++ auto src_tag = memory_desc_matches_one_of_tag( ++ *src_md, format_tag::ba, format_tag::cdba); ++ ACL_CHECK_SUPPORT( ++ utils::one_of(format_tag::undef, src_tag), ++ ""); ++ ++ arm_compute::TensorShape acl_tensor_shape_in; ++ arm_compute::TensorShape acl_tensor_shape_out; ++ // Need even amount of dims in dim 0 for ACL kernel (eg mulitple of 8 rows when blocking by 8) ++ int dim_0_rounded_up; ++ ++ // Switch for 2 or 4 dim tensors ++ switch(ndims) ++ { ++ // Currently for Ab4a and Ab8a ++ // No format_tag for these, have to deduce from stride ++ case 2: ++ { ++ if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){ ++ return status::unimplemented; ++ } ++ int dst_dim_1 = dst_md->dims[1]; ++ int dst_dim_0_stride = dst_md->format_desc.blocking.strides[0]; ++ int dst_dim_1_stride = dst_md->format_desc.blocking.strides[1]; ++ // Interleave of 4 or 8 that stride for dim 1 ++ if (dst_dim_1_stride != 4 && dst_dim_1_stride != 8){ ++ return status::unimplemented; ++ } ++ // Check to ensure it's a blocking transpose ++ if (dst_dim_1 * dst_dim_1_stride != dst_dim_0_stride){ ++ return status::unimplemented; ++ } ++ if(dst_dim_1_stride == 4){ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 4); ++ } else { ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 8); ++ } ++ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[1], src_md->dims[0]); ++ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[1], dim_0_rounded_up); ++ ++ break; ++ } ++ // Currently for Acdb4a and Acdb8a ++ case 4: ++ { ++ ++ auto dst_tag = memory_desc_matches_one_of_tag( ++ *dst_md, format_tag::Acdb4a, format_tag::Acdb8a); ++ ACL_CHECK_SUPPORT( ++ utils::one_of(format_tag::undef, dst_tag), ++ ""); ++ if(dst_tag == format_tag::Acdb4a){ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 4); ++ } ++ else{ ++ // Set Dest WeightFormat ++ _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; ++ dim_0_rounded_up ++ = utils::rnd_up(src_md->dims[0], 8); ++ } ++ // Currently only supporting AxBx1x1 cases ++ if(dst_md->dims[2] != 1 || dst_md->dims[3] != 1){ ++ return 
status::unimplemented; ++ } ++ if(dst_md->dims[0] == 1 || dst_md->dims[1] == 1){ ++ return status::unimplemented; ++ } ++ acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], src_md->dims[0]); ++ acl_tensor_shape_out = arm_compute::TensorShape(src_md->dims[3], src_md->dims[2], src_md->dims[1], dim_0_rounded_up); ++ break; ++ } ++ default: ++ return status::unimplemented; ++ } ++ ++ // Choose the data layout ++ // bool is_nspc = utils::one_of(src_tag, format_tag::nhwc); ++ const auto acl_layout = arm_compute::DataLayout::NCHW; ++ ++ // Set Source WeightFormat ++ _pd->app_.src_wf = arm_compute::WeightFormat::OHWI; ++ ++ // Create ACL tensor infos ++ const data_type_t data_type = src_d.data_type(); ++ const arm_compute::DataType acl_data_t ++ = acl_utils::get_acl_data_t(data_type); ++ _pd->app_.src_info = arm_compute::TensorInfo( ++ acl_tensor_shape_in, 1, acl_data_t, acl_layout); ++ _pd->app_.dst_info = arm_compute::TensorInfo( ++ acl_tensor_shape_out, 1, acl_data_t, acl_layout); ++ ++ // Init scratch memory, not used so 0 in this implementation ++ _pd->init_scratchpad_md(); ++ ++ return safe_ptr_assign(*reorder_pd, _pd); ++ } // create ++ ++ friend dnnl::impl::impl_list_item_t; ++ acl_reorder_conf_t app_; ++ ++ }; // pd_t ++ ++ acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {} ++ ++ status_t create_resource( ++ engine_t *engine, resource_mapper_t &mapper) const override { ++ if (mapper.has_resource(this)) return status::success; ++ ++ auto r = utils::make_unique(); ++ if (!r) return status::out_of_memory; ++ ++ // Configure the resource based on information from primitive descriptor ++ CHECK(r->configure(pd()->app_)); ++ ++ mapper.add(this, std::move(r)); ++ return status::success; ++ } ++ ++ status_t execute(const exec_ctx_t &ctx) const override { ++ return execute_forward(ctx); ++ } ++ ++private: ++ // To guard the const execute_forward, the mutex must be 'mutable' ++ mutable std::mutex mtx; ++ status_t execute_forward(const exec_ctx_t &ctx) const; ++ const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } ++ ++ ++}; // acl_reorder_fwd_t ++ ++} // namespace aarch64 ++} // namespace cpu ++} // namespace impl ++} // namespace dnnl ++ ++#endif // CPU_AARCH64_ACL_REORDER_HPP +diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +index bccd2f75f4..5e5ea331ba 100644 +--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp ++++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp +@@ -15,6 +15,7 @@ + *******************************************************************************/ + + #include "cpu/reorder/cpu_reorder.hpp" ++#include "cpu/aarch64/acl_reorder.hpp" + + namespace dnnl { + namespace impl { +@@ -27,6 +28,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + // f32 -> f32 + {{f32, f32, 0}, { + REG_FAST_DIRECT_COPY_F32_F32 ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) +@@ -64,6 +66,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { + nullptr, + }}, + {{f32, f32, 4}, { ++ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::wino_reorder_t)) + + CPU_REORDER_INSTANCE(rnn_weights_reorder_t) diff --git a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch index 
7e3725af270292..0e0cb39e82f1bb 100644 --- a/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch +++ b/third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch @@ -1,3 +1,20 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* + diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/aarch64/acl_threadpool_scheduler.cpp index 418d7f30f..439ca862e 100644 --- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp From dc9ab19efad29836f23e5d90b91bc3790684c099 Mon Sep 17 00:00:00 2001 From: johnnkp <22496821+johnnkp@users.noreply.github.com> Date: Tue, 25 Jul 2023 00:17:39 +0800 Subject: [PATCH 057/410] gpu_kernel_helper.h: Code cleanup https://github.com/tensorflow/tensorflow/pull/61339#discussion_r1272461821 --- tensorflow/core/util/gpu_kernel_helper.h | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h index 7f22d65360696a..ba6ae9ab153fa0 100644 --- a/tensorflow/core/util/gpu_kernel_helper.h +++ b/tensorflow/core/util/gpu_kernel_helper.h @@ -180,18 +180,6 @@ __host__ __device__ inline int tf_min(int x, int y) { __host__ __device__ inline int tf_max(int x, int y) { return max(x, y); } -__host__ __device__ inline int tf_min(unsigned int x, int y) { - return min(static_cast(x), y); -} -__host__ __device__ inline int tf_max(unsigned int x, int y) { - return max(static_cast(x), y); -} -__host__ __device__ inline int tf_min(int x, unsigned int y) { - return min(x, static_cast(y)); -} -__host__ __device__ inline int tf_max(int x, unsigned int y) { - return max(x, static_cast(y)); -} #endif #endif From 3192751883708e8560ca0ff2039a78ca443c3b14 Mon Sep 17 00:00:00 2001 From: "guozhong.zhuang" Date: Mon, 24 Jul 2023 11:06:07 -0700 Subject: [PATCH 058/410] Refine code per Penport's review suggestion --- tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index b9cfe85369e628..3db1738a9b4462 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -79,7 +79,7 @@ inline bool ExecuteSingleThreadedGemm(int64_t m, int64_t n, int64_t k, constexpr float kHeuristicMultiplier = 1.01; const float mul_size = bytes * (m * n + k * (m + n)); const float l2_heur = l2_size * kHeuristicMultiplier; - return (!(mul_size < 0) && (mul_size < l2_heur)); + return (mul_size >= 0 && mul_size < l2_heur); } // This structure aggregates multiple inputs to MklDnnMatMul* methods. 
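The hunk above simplifies the guard in the single-threaded GEMM heuristic, which returns true only when the whole problem (the m x n output plus the k x (m + n) inputs) is expected to fit in L2. A short worked example makes the threshold concrete; the 1 MiB L2 size and fp32 element size below are assumptions for illustration, and the helper takes them as plain parameters rather than querying the platform.

#include <cstdint>
#include <iostream>

// Standalone restatement of the heuristic: prefer the single-threaded path only
// when all matrices are expected to fit in L2 (with a 1% slack factor).
bool FitsInL2(int64_t m, int64_t n, int64_t k, int64_t bytes, int64_t l2_size) {
  constexpr float kHeuristicMultiplier = 1.01;
  const float mul_size = bytes * (m * n + k * (m + n));
  const float l2_heur = l2_size * kHeuristicMultiplier;
  return mul_size >= 0 && mul_size < l2_heur;
}

int main() {
  const int64_t l2 = 1 << 20;  // assume a 1 MiB L2 purely for illustration
  // 64 x 64 x 64 fp32: 4 * (4096 + 64 * 128) = 48 KiB, fits -> single-threaded.
  std::cout << FitsInL2(64, 64, 64, 4, l2) << "\n";     // 1
  // 512 x 512 x 512 fp32: 4 * (262144 + 512 * 1024) = 3 MiB, does not fit.
  std::cout << FitsInL2(512, 512, 512, 4, l2) << "\n";  // 0
  return 0;
}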
From 865213a30aa27e868c2beb4181b1c313b0cff818 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Mon, 24 Jul 2023 12:24:20 -0700 Subject: [PATCH 059/410] [xla:gpu] NFC: Convert AddConcurrentRegionsPass to a function pass The pass only operates on function but not the whole module. PiperOrigin-RevId: 550639561 --- .../gpu/transforms/add_concurrent_regions.cc | 23 +++++++++++-------- .../xla/mlir/backends/gpu/transforms/passes.h | 2 +- .../mlir/backends/gpu/transforms/passes.td | 2 +- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc index bdb3944d32d64e..2a8457cf1630df 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc @@ -202,22 +202,25 @@ void InsertConcurrentRegions(FuncOp capture_func, //===----------------------------------------------------------------------===// void AddConcurrentRegionsPass::runOnOperation() { - FuncOp func_op = getOperation(); + ModuleOp module = getOperation(); + SymbolTable sym_table(module); + CustomCallDeclarations custom_calls(std::move(sym_table)); - if (!absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.cuda.graph.capture")) { - return; - } + auto func_ops = llvm::to_vector(module.getOps()); - SymbolTable sym_table(func_op->getParentOfType()); - CustomCallDeclarations custom_calls(std::move(sym_table)); - InsertConcurrentRegions(func_op, custom_calls, - getAnalysis()); + for (auto func_op : func_ops) { + // Find the cuda graph capture function. + if (absl::StrContains(func_op.getSymNameAttr().str(), + "xla.gpu.cuda.graph.capture")) { + InsertConcurrentRegions(func_op, custom_calls, + getAnalysis()); + } + } } } // namespace -std::unique_ptr> createAddConcurrentRegionsPass() { +std::unique_ptr> createAddConcurrentRegionsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h index 56a1e01b7db6dc..7d3ea262443770 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h @@ -110,7 +110,7 @@ createOutlineCudaGraphsPass(int32_t cuda_graph_level, int32_t min_graph_size); // Passes for marking concurrent region in CUDA graph capture function. 
//===----------------------------------------------------------------------===// -std::unique_ptr> +std::unique_ptr> createAddConcurrentRegionsPass(); //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 3e5909236c267b..29cf6991e26431 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -197,7 +197,7 @@ def OutlineCudaGraphsPass : //===----------------------------------------------------------------------===// def AddConcurrentRegionsPass: - Pass<"xla-gpu-add-concurrent-regions", "mlir::func::FuncOp"> { + Pass<"xla-gpu-add-concurrent-regions", "mlir::ModuleOp"> { let summary = "Identify and mark concurrent regions in CUDA graph capture " "functions"; From f3db89465f1eb558daad5bb1e72b3b61954f7de3 Mon Sep 17 00:00:00 2001 From: Gauri1 Deshpande Date: Mon, 24 Jul 2023 12:59:33 -0700 Subject: [PATCH 060/410] add extra blank line --- tensorflow/core/kernels/fused_batch_norm_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 643d949485049c..79ac4484b0a489 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -1963,7 +1963,6 @@ REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3") .TypeConstraint("U"), FusedBatchNormGradOpV3); - REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3") .Device(DEVICE_CPU) .TypeConstraint("T") @@ -1975,6 +1974,7 @@ REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3") .TypeConstraint("T") .TypeConstraint("U"), FusedBatchNormGradOpV3); + #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM REGISTER_KERNEL_BUILDER( From 86bd340bcdc35551035da355cd5af3be84710011 Mon Sep 17 00:00:00 2001 From: Eugene Burmako Date: Mon, 24 Jul 2023 13:03:38 -0700 Subject: [PATCH 061/410] Integrate StableHLO at openxla/stablehlo@f4c43e12 Manual changes: * CMakeLists.txt: keep MLIR-HLO-only customizations to the CMake build. * BUILD.bazel, stablehlo/dialect/CMakeLists.txt, stablehlo/dialect/Base.h, stablehlo/dialect/Base.cpp, stablehlo/dialect/ExperimentalOps.h, stablehlo/dialect/ExperimentalOps.cpp, stablehlo/transforms/StablehloCanonicalizeDynamism.cpp, stablehlo/transforms/StablehloRefineShapes.cpp, stablehlo/tests/stablehlo_canonicalize_dynamism.mlir, stablehlo/tests/stablehlo_refine_shapes.mlir: keep XLA-only customizations to StableHLO shape refinement. 
PiperOrigin-RevId: 550649848 --- third_party/stablehlo/temporary.patch | 64 ++++----------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 10 insertions(+), 58 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 5e86db006cb512..1b32eb1c8480f3 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -184,23 +184,15 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp --- stablehlo/stablehlo/dialect/Base.cpp +++ stablehlo/stablehlo/dialect/Base.cpp -@@ -155,8 +155,14 @@ - LogicalResult matchInts(Value value, SmallVector& result) { +@@ -156,6 +156,7 @@ DenseIntElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) return failure(); -+ - // Signless types are treated as signed, per StableHLO convention. -- auto isUnsigned = attr.getType().getElementType().isUnsignedInteger(); -+ // Unless the type is i1 (which models boolean type from the StableHLO spec), -+ // in which case it's considered to be unsigned. -+ auto elementType = attr.getType().getElementType(); -+ auto isUnsigned = elementType.isUnsignedInteger() || -+ elementType.getIntOrFloatBitWidth() == 1; -+ - for (auto element : attr.getValues()) { - result.push_back(APSInt(element, /*isUnsigned=*/isUnsigned)); - } -@@ -594,5 +600,18 @@ + ++ // Signless types are treated as signed, per StableHLO convention. + // Unless the type is i1 (which models boolean type from the StableHLO spec), + // in which case it's considered to be unsigned. + auto elementType = attr.getType().getElementType(); +@@ -599,5 +600,18 @@ return UnrankedTensorType::get(components.getElementType()); } @@ -849,17 +841,6 @@ diff --ruN a/stablehlo/stablehlo/integrations/python/mlir/dialects/StablehloOps. 
include "stablehlo/dialect/StablehloOps.td" #endif -diff --ruN a/stablehlo/stablehlo/reference/Element.cpp b/stablehlo/stablehlo/reference/Element.cpp ---- stablehlo/stablehlo/reference/Element.cpp -+++ stablehlo/stablehlo/reference/Element.cpp -@@ -18,6 +18,7 @@ - #include - - #include "llvm/ADT/APFloat.h" -+#include "llvm/ADT/APSInt.h" - #include "llvm/Support/Error.h" - #include "mlir/Dialect/Complex/IR/Complex.h" - #include "mlir/IR/BuiltinAttributes.h" diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir @@ -1084,36 +1065,7 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir --- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -257,14 +257,26 @@ - - // ----- - --// CHECK-LABEL: func @eval_convert --func.func @eval_convert() -> tensor { -+// CHECK-LABEL: func @eval_convert_common_case -+func.func @eval_convert_common_case() -> tensor { - // CHECK-NOT: stablehlo.convert - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<4> : tensor - // CHECK: return [[RESULT]] - %0 = stablehlo.constant dense<4> : tensor - %1 = stablehlo.convert %0 : (tensor) -> tensor - func.return %1 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @eval_convert_i1 -+func.func @eval_convert_i1() -> tensor<2xi64> { -+ // CHECK-NOT: stablehlo.convert -+ // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[1, 0]> : tensor<2xi64> -+ // CHECK: return [[RESULT]] -+ %0 = stablehlo.constant dense<[true, false]> : tensor<2xi1> -+ %1 = stablehlo.convert %0 : (tensor<2xi1>) -> tensor<2xi64> -+ return %1 : tensor<2xi64> - } - - // ----- -@@ -595,12 +607,45 @@ +@@ -607,12 +607,45 @@ // ----- diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index c7d30c5b560974..2ca73b67e5523f 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "41bad512515d609ccd3896d74bf697e7d456e1d3" - STABLEHLO_SHA256 = "01d143b57efda2fcf5e3482cbd0c4beae2a51164082e0797f0093cdbd8c82b06" + STABLEHLO_COMMIT = "f4c43e121aa94a27eb1b846f2e1d776732203400" + STABLEHLO_SHA256 = "9dacfeb1bdb8468b8736995cffc5db42de04b6d1b2f0572751aa41f10232d9a8" # LINT.ThenChange(Google-internal path) tf_http_archive( From 1f582c3a187be0b408f02ea6ed9539e012774e90 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jul 2023 13:04:40 -0700 Subject: [PATCH 062/410] [xla:gpu] Add OpenXLA executable to GpuExecutable XLA_FLAGS=--xla_gpu_enable_openxla_runtime enables OpenXLA/IREE backend for XLA:GPU. Currently it's a proof-of-concept quality, and does not support most of the valid XLA programs and uses legacy sync ABI. OpenXLA backend is disabled in open source build because we do not have bazel dependency configured. 
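For reference, the backend is gated on a new DebugOptions field in addition to the
XLA_FLAGS entry. A minimal sketch of flipping it programmatically, assuming a caller
that already builds its DebugOptions from flags (only the field name comes from this
change; the surrounding lines are illustrative):

  // Hedged sketch: opt a module into the experimental OpenXLA/IREE runtime.
  // Enabling it also bypasses the "classic" XLA runtime path.
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  opts.set_xla_gpu_enable_openxla_runtime(true);
  module_config.set_debug_options(opts);  // module_config: an xla::HloModuleConfig
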
PiperOrigin-RevId: 550650123 --- .../compiler/xla/debug_options_flags.cc | 8 + .../conversion/convert_compiled_ops.cc | 4 +- .../mlir/backends/openxla/transforms/BUILD | 12 +- .../mlir/backends/openxla/transforms/passes.h | 25 +- tensorflow/compiler/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/autotuner_compile_util.cc | 2 + .../service/gpu/compile_module_to_llvm_ir.cc | 96 +++++-- .../service/gpu/compile_module_to_llvm_ir.h | 3 +- .../xla/service/gpu/gpu_executable.cc | 76 +++++- .../compiler/xla/service/gpu/gpu_executable.h | 20 +- .../compiler/xla/service/gpu/openxla/BUILD | 99 +++++++ .../xla/service/gpu/openxla/compiler.cc | 238 ++++++++++++++++ .../xla/service/gpu/openxla/compiler.h | 95 +++++++ .../xla/service/gpu/openxla/default/BUILD | 17 ++ .../service/gpu/openxla/default/compiler.h | 33 +++ .../xla/service/gpu/openxla/executable.cc | 253 ++++++++++++++++++ .../xla/service/gpu/openxla/executable.h | 167 ++++++++++++ tensorflow/compiler/xla/xla.proto | 11 +- 18 files changed, 1118 insertions(+), 43 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/BUILD create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/compiler.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/compiler.h create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/default/BUILD create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/default/compiler.h create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/executable.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/executable.h diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 31cd1f3e4a0163..282d3e30e27e8c 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -134,6 +134,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); + // OpenXLA/IREE runtime flags. + opts.set_xla_gpu_enable_openxla_runtime(false); + // Set 4GB space limit for redzone scratch allocator. 
opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); opts.set_xla_gpu_redzone_padding_bytes(8 * 1024 * 1024); @@ -961,6 +964,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_xla_runtime_executable), debug_options->xla_gpu_enable_xla_runtime_executable(), "Whether to enable XLA runtime for XLA:GPU backend")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_openxla_runtime", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_openxla_runtime), + debug_options->xla_gpu_enable_openxla_runtime(), + "Whether to enable OpenXLA runtime for XLA:GPU backend")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", int64_setter_for( diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc index f1eea73ef63710..924da08d5ca78c 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc @@ -310,8 +310,8 @@ LogicalResult ConvertCompiledOp::matchAndRewrite( auto src_memref = cast>(copy->source_value()); auto dst_memref = cast>(copy->destination_value()); - auto src = state->remapped[block][src_memref]; - auto dst = state->remapped[block][dst_memref]; + auto src = state->remapped[block][stripReinterpretCast(src_memref)]; + auto dst = state->remapped[block][stripReinterpretCast(dst_memref)]; assert(src && "unknown mapping from `src` memref to a tensor"); assert(dst && "unknown mapping from `dst` memref to a tensor"); diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD index 4de455672ebda4..2856f2551b3c7b 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD @@ -31,8 +31,8 @@ package( # "passes.cc", # ], # hdrs = ["passes.h"], -# # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] -# # because IREE targets are not compatible with `non_prod` constraint. +# # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not +# # compatible with `non_prod` constraint. # compatible_with = [], # deps = [ # ":passes_inc_gen", @@ -54,4 +54,10 @@ package( # ], # ) # -# copybara:uncomment_end +# copybara:uncomment_end_and_comment_begin +cc_library( + name = "passes", + hdrs = ["passes.h"], + defines = ["XLA_DISABLE_OPENXLA_RUNTIME=1"], +) +# copybara:comment_end diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h index 2a1e1a01ca3466..7249094b880e33 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h @@ -16,6 +16,26 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_TRANSFORMS_PASSES_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_TRANSFORMS_PASSES_H_ +//===----------------------------------------------------------------------===// +// TODO(ezhulenev): We currently do not build with OpenXLA runtime in open +// source because we do not have bazel dependency from XLA to IREE. 
+#if XLA_DISABLE_OPENXLA_RUNTIME +//===----------------------------------------------------------------------===// + +namespace mlir { +class OpPassManager; +} // namespace mlir + +namespace xla::gpu { +class ThunkSequence; +inline void populateOpenXlaRuntimePasses(mlir::OpPassManager&, ThunkSequence*) { +} +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +#else // !XLA_DISABLE_OPENXLA_RUNTIME +//===----------------------------------------------------------------------===// + #include #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -35,8 +55,8 @@ void populateOpenXlaRuntimePasses(mlir::OpPassManager& pm, // Conversion from LMHLO dialects to OpenXLA runtime //===----------------------------------------------------------------------===// -std::unique_ptr> createConvertToOpenXlaPass( - ThunkSequence* thunk_sequence = nullptr); +std::unique_ptr > +createConvertToOpenXlaPass(ThunkSequence* thunk_sequence = nullptr); //===----------------------------------------------------------------------===// // OpenXLA passes registration @@ -46,4 +66,5 @@ void registerOpenXlaPases(); } // namespace xla::gpu +#endif // !XLA_DISABLE_OPENXLA_RUNTIME #endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a8e23d8b273cbe..4bf5b9d4159a14 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -931,6 +931,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:xla_debug_info_manager", + "//tensorflow/compiler/xla/service/gpu/openxla:executable", "//tensorflow/compiler/xla/service/gpu/runtime:executable", "//tensorflow/compiler/xla/service/gpu/runtime:support", "//tensorflow/compiler/xla/stream_executor", @@ -2260,6 +2261,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir/backends/gpu/transforms:passes", + "//tensorflow/compiler/xla/mlir/backends/openxla/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/transforms:compilation_pipeline_gpu", "//tensorflow/compiler/xla/mlir_hlo:transforms_gpu_passes", "//tensorflow/compiler/xla/service:bitcast_dtypes_expander", diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 14c185675abc74..46288839d6567a 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -121,6 +121,8 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, opts_.set_xla_gpu_force_compilation_parallelism(1); // Avoid using GPU graphs as we don't want to measure graph construction time. opts_.set_xla_gpu_cuda_graph_level(0); + // Disable experimental OpenXLA runtime. + opts_.set_xla_gpu_enable_openxla_runtime(false); } StatusOr> diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index 7c20c6bec73819..fd9081b3b716c1 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.h" #include "tensorflow/compiler/xla/mlir_hlo/transforms/gpu_passes.h" #include "tensorflow/compiler/xla/service/bitcast_dtypes_expander.h" @@ -129,6 +130,22 @@ static Status LowerToXlaGpuRuntime(mlir::ModuleOp module, return OkStatus(); } +// Lowers MLIR module to the OpenXla runtime (aka IREE input dialects). +static Status LowerToOpenXlaRuntime(mlir::ModuleOp module, + llvm::StringRef entry_function_name, + llvm::ArrayRef buffer_sizes, + ThunkSequence* thunk_sequence, + const DebugOptions& debug_options) { + mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit); + populateOpenXlaRuntimePasses(pm, thunk_sequence); + + if (pm.run(module).failed()) { + return InternalError("Failed to lower LMHLO to OpenXLA input dialects."); + } + + return OkStatus(); +} + void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence) { for (std::unique_ptr& thunk : *thunk_sequence) { @@ -162,23 +179,24 @@ std::optional DummyCanShareBufferFunction(const HloInstruction*, return std::nullopt; } -static StatusOr LowerToJitRt( +static void ForwardCollectiveAttrs(mlir::ModuleOp module, + llvm::StringRef entry_function_name, + const HloModuleConfig& config) { + mlir::OpBuilder b(module.getContext()); + auto func = module.lookupSymbol(entry_function_name); + func->setAttr("replica_count", b.getI64IntegerAttr(config.replica_count())); + func->setAttr("num_partitions", b.getI64IntegerAttr(config.num_partitions())); +} + +StatusOr LowerToJitRt( mlir::ModuleOp mlir_module, llvm::StringRef entry_function_name, llvm::ArrayRef buffer_sizes, const HloModuleConfig& module_config, std::unique_ptr thunk_sequence, const HloModule* hlo_module_for_dump) { // Forward collective (NCCL) attributes for use by the lowering pipeline. - mlir::OpBuilder builder(mlir_module.getContext()); - mlir::IntegerAttr replica_count_attr = - builder.getI64IntegerAttr(module_config.replica_count()); - mlir::IntegerAttr num_partitions_attr = - builder.getI64IntegerAttr(module_config.num_partitions()); - mlir::func::FuncOp func = - mlir_module.lookupSymbol(entry_function_name); - func->setAttr("replica_count", replica_count_attr); - func->setAttr("num_partitions", num_partitions_attr); - - // Lower LMHLO operations to the JitRt compatible custom calls. + ForwardCollectiveAttrs(mlir_module, entry_function_name, module_config); + + // Lower LMHLO operations to the XLA:GPU runtime custom calls. TF_RETURN_IF_ERROR(LowerToXlaGpuRuntime( mlir_module, {entry_function_name.data(), entry_function_name.size()}, buffer_sizes, thunk_sequence.get(), module_config.debug_options())); @@ -197,6 +215,32 @@ static StatusOr LowerToJitRt( module_config.debug_options()); } +StatusOr LowerToOpenXla( + std::unique_ptr ctx, + mlir::OwningOpRef module, + llvm::StringRef entry_function_name, llvm::ArrayRef buffer_sizes, + const HloModuleConfig& module_config, + std::unique_ptr thunk_sequence, + const HloModule* hlo_module_for_dump) { + // Forward collective (NCCL) attributes for use by the lowering pipeline. + ForwardCollectiveAttrs(*module, entry_function_name, module_config); + + // Lower LMHLO operations to the OpenXLA compiler input dialects. 
+ TF_RETURN_IF_ERROR(LowerToOpenXlaRuntime( + *module, {entry_function_name.data(), entry_function_name.size()}, + buffer_sizes, thunk_sequence.get(), module_config.debug_options())); + + if (hlo_module_for_dump != nullptr) { + std::string module_str = llvm_ir::DumpToString(*module); + DumpToFileInDirOrStdout(*hlo_module_for_dump, "gpu_rt_host", "mlir", + module_str); + } + + return std::make_unique( + std::move(ctx), std::move(module), entry_function_name.str(), + buffer_sizes.vec(), module_config.debug_options()); +} + StatusOr> CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, @@ -369,10 +413,10 @@ Status CompileModuleToLlvmIrImpl( uint64_t start_usecs = tsl::Env::Default()->NowMicros(); mlir::DialectRegistry registry; IrEmitterUnnested::GetDependentDialects(registry); - mlir::MLIRContext mlir_context(registry); - mlir_context.getDiagEngine().registerHandler(DiagnosticHandler); + auto mlir_context = std::make_unique(registry); + mlir_context->getDiagEngine().registerHandler(DiagnosticHandler); mlir::OwningOpRef mlir_module = - mlir::ModuleOp::create(mlir::Builder(&mlir_context).getUnknownLoc()); + mlir::ModuleOp::create(mlir::Builder(mlir_context.get()).getUnknownLoc()); TF_RETURN_IF_ERROR( HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module)); @@ -393,7 +437,7 @@ Status CompileModuleToLlvmIrImpl( IrEmitterContext ir_emitter_context( hlo_module, /*buffer_assignment=*/nullptr, platform_name, gpu_device_info, - cuda_compute_capability, rocm_compute_capability, &mlir_context, + cuda_compute_capability, rocm_compute_capability, mlir_context.get(), results->llvm_module.get()); ir_emitter_context.set_allocations(results->allocations); @@ -425,14 +469,16 @@ Status CompileModuleToLlvmIrImpl( RecordHloToLlvmDuration(end_usecs - start_usecs); } + // Sizes of all buffers required for running XLA module. + std::vector buffer_sizes; + llvm::transform( + results->allocations, std::back_inserter(buffer_sizes), + [](const BufferAllocation& allocation) { return allocation.size(); }); + // TODO(ezhulenev): Remove the FP8 check once https://reviews.llvm.org/D140088 // is submitted. Currently we can't emit LLVM IR with fp8 types. 
if (IsXlaRuntimeExecutableEnabled(hlo_module->config()) && !HasFp8(*hlo_module)) { - std::vector buffer_sizes; - llvm::transform( - results->allocations, std::back_inserter(buffer_sizes), - [](const BufferAllocation& allocation) { return allocation.size(); }); TF_ASSIGN_OR_RETURN( results->executable, LowerToJitRt(*mlir_module, entry_function.getName(), buffer_sizes, @@ -441,6 +487,16 @@ Status CompileModuleToLlvmIrImpl( return OkStatus(); } + if (IsOpenXlaRuntimeEnabled(hlo_module->config())) { + TF_ASSIGN_OR_RETURN( + results->executable, + LowerToOpenXla(std::move(mlir_context), std::move(mlir_module), + entry_function.getName(), buffer_sizes, + hlo_module->config(), ir_emitter->ConsumeThunkSequence(), + /*hlo_module_for_dump=*/hlo_module)); + return OkStatus(); + } + auto thunk_sequence = ir_emitter->ConsumeThunkSequence(); ForAllThunks([](Thunk* thunk) { thunk->ClearCompileTimeInfo(); }, thunk_sequence.get()); diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.h b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.h index 30f4dca37e2f38..3bdf91e80588de 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.h @@ -42,7 +42,8 @@ struct CompileModuleResults { std::unique_ptr buffer_assignment; std::vector allocations; std::variant + GpuExecutable::OwnedGpuRuntimeProgram, + GpuExecutable::OwnedOpenXlaRuntimeProgram> executable; EntryFunctionAttributes entry_func_attrs; std::vector constants; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 10595fc9bbb5b7..cfafa628346571 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/non_atomically_upgradeable_rw_lock.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/executable.h" #include "tensorflow/compiler/xla/service/gpu/runtime/executable.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -70,8 +71,17 @@ limitations under the License. namespace xla { namespace gpu { +// If OpenXLA runtime is enabled, it automatically disables "classic" XLA +// runtime which is enabled by default. bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config) { - return config.debug_options().xla_gpu_enable_xla_runtime_executable(); + bool runtime = config.debug_options().xla_gpu_enable_xla_runtime_executable(); + bool openxla = config.debug_options().xla_gpu_enable_openxla_runtime(); + return runtime && !openxla; +} + +bool IsOpenXlaRuntimeEnabled(const HloModuleConfig& config) { + bool openxla = config.debug_options().xla_gpu_enable_openxla_runtime(); + return openxla; } namespace { @@ -108,6 +118,15 @@ StatusOr> GpuExecutable::Create(Params params) { return result; } + if (std::holds_alternative(executable)) { + auto& program = std::get(executable); + TF_ASSIGN_OR_RETURN( + result->openxla_executable_, + OpenXlaRuntimeExecutable::Create(std::move(program), result->text(), + result->binary())); + return result; + } + return InternalError("No XLA gpu executable was provided"); } @@ -511,6 +530,36 @@ static Status ExecuteXlaRuntime(const std::string& module_name, block_host_until_done ? 
run_options->stream() : nullptr); } +static Status ExecuteOpenXlaRuntime( + const std::string& module_name, ModuleIdentifier module_id, + OpenXlaRuntimeExecutable& openxla_executable, + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, + const BufferAllocation* temp_buffer, bool block_host_until_done) { + uint64_t start_nanos = tsl::Env::Default()->NowNanos(); + + tsl::profiler::TraceMe hlo_module_activity( + [&] { return absl::StrCat(module_name, ":XLA GPU module"); }, + tsl::profiler::TraceMeLevel::kInfo); + + ScopedAnnotationAlways annotation([&] { + std::string module_id_str; + if (module_id >= 0) { + module_id_str = absl::StrFormat(",program_id=%d", module_id); + } + return absl::StrFormat("XlaModule:#hlo_module=%s%s#", module_name, + module_id_str); + }); + + auto executed = + openxla_executable.Execute(run_options, buffer_allocations, temp_buffer); + if (!executed.ok()) return executed; + + return MaybeSyncAndProfile( + run_options, start_nanos, + block_host_until_done ? run_options->stream() : nullptr); +} + Status GpuExecutable::PopulatePersistentTempBuffers( se::StreamExecutor* executor) { auto search = persistent_temp_buffers_.find(executor); @@ -761,21 +810,28 @@ Status GpuExecutable::ExecuteThunksOrXlaRuntime( : false); } - if (gpu_runtime_executable_) { - // Match IrEmitter's temp buffer allocation for kernel launches. See - // IrEmitterUnnested::BuildKernelThunkImpl(). - const BufferAllocation* temp_buffer = nullptr; - for (const BufferAllocation& alloc : allocations_) { - if (alloc.IsPreallocatedTempBuffer()) { - // Retrieve the first seen temp buffer. - if (temp_buffer == nullptr) temp_buffer = &alloc; - } + // Match IrEmitter's temp buffer allocation for kernel launches. See + // IrEmitterUnnested::BuildKernelThunkImpl(). + const BufferAllocation* temp_buffer = nullptr; + for (const BufferAllocation& alloc : allocations_) { + if (alloc.IsPreallocatedTempBuffer()) { + // Retrieve the first seen temp buffer. + if (temp_buffer == nullptr) temp_buffer = &alloc; } + } + + if (gpu_runtime_executable_) { return ExecuteXlaRuntime(module_name_, unique_id, *gpu_runtime_executable_, run_options, text_, binary_, buffer_allocations, temp_buffer, block_host_until_done, gpu_lock); } + if (openxla_executable_) { + return ExecuteOpenXlaRuntime(module_name_, unique_id, *openxla_executable_, + run_options, buffer_allocations, temp_buffer, + block_host_until_done); + } + return FailedPrecondition("Expected XLA gpu executable is not supplied."); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 830cc20914e517..94ceac748ed98a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/gpu/non_atomically_upgradeable_rw_lock.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/executable.h" #include "tensorflow/compiler/xla/service/gpu/runtime/executable.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" @@ -52,6 +53,9 @@ namespace gpu { // Returns whether GpuExecutable runs with Xla Runtime. bool IsXlaRuntimeExecutableEnabled(const HloModuleConfig& config); +// Returns whether GpuExecutable runs with OpenXla Runtime. 
+bool IsOpenXlaRuntimeEnabled(const HloModuleConfig& config); + // GPU-targeting implementation of the XLA Executable interface. // // Launches the given GPU kernel via the StreamExecutor. @@ -61,6 +65,7 @@ class GpuExecutable : public Executable { public: using OwnedThunkSequence = std::unique_ptr; using OwnedGpuRuntimeProgram = std::unique_ptr; + using OwnedOpenXlaRuntimeProgram = std::unique_ptr; struct ConstantInfo { std::string symbol_name; @@ -84,9 +89,12 @@ class GpuExecutable : public Executable { std::string asm_text; std::vector binary; GpuVersion gpu_version; - // The GpuExecutable will either execute Thunks or a XLA Runtime compiled - // native function depending on which is supplied. - std::variant executable; + // The GpuExecutable will either execute Thunks, XLA runtime executable + // (native function) or OpenXLA runtime executable (IREE VM function) + // depending on which is supplied. + std::variant + executable; xla::EntryFunctionAttributes entry_func_attrs; std::vector constants; absl::flat_hash_map output_info; @@ -271,7 +279,7 @@ class GpuExecutable : public Executable { GpuVersion gpu_version_; // The thunks to be invoked by this GpuExecutable. They are generated by the - // IrEmitter (null if Xla runtime is enabled). + // IrEmitter (null if Xla/OpenXla runtime is enabled). OwnedThunkSequence thunks_; // Gpu runtime executable that encapsulates all the state for running Gpu @@ -279,6 +287,10 @@ class GpuExecutable : public Executable { // Xla runtime is enabled). std::unique_ptr gpu_runtime_executable_; + // OpenXLA executable that encapsulates all the state for running XLA:GPU + // executables compiled to IREE VM modules (including VM module itself). + std::unique_ptr openxla_executable_; + xla::EntryFunctionAttributes entry_func_attrs_; std::string module_name_; diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD new file mode 100644 index 00000000000000..50113a4ffabf34 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -0,0 +1,99 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/tsl/platform:build_config.bzl", "tf_platform_deps") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = ["//tensorflow/compiler/xla:friends"], +) + +# copybara:uncomment_begin(not supported in OSS build) +# +# # Add `--define=xla_gpu_bundle_lib_iree_compiler=1` to build command to bundle `libIREECompiler.so` +# # with XLA:GPU by default. Otherwise use `XLA_OPENXLA_IREE_COMPILER_LIB` environment variable to +# # load custom compiler library. +# config_setting( +# name = "bundle_lib_iree_compiler", +# values = { +# "define": "xla_gpu_bundle_lib_iree_compiler=1", +# }, +# ) +# +# cc_library( +# name = "compiler", +# srcs = ["compiler.cc"], +# hdrs = ["compiler.h"], +# # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not +# # compatible with `non_prod` constraint. 
+# compatible_with = [], +# data = select({ +# ":bundle_lib_iree_compiler": ["//third_party/iree/lib:libIREECompiler.so"], +# "//conditions:default": [], +# }), +# deps = [ +# "@com_google_absl//absl/base", +# "//third_party/iree/compiler/bindings/c:headers", +# "//third_party/iree/compiler/bindings/c:loader", +# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", +# "@llvm-project//llvm:Support", +# "@llvm-project//mlir:IR", +# "@llvm-project//mlir:Support", +# "//tensorflow/compiler/xla:status", +# "//tensorflow/compiler/xla:util", +# "//tensorflow/tsl/platform", +# ] + tf_platform_deps( +# "compiler", +# platform_dir = "//tensorflow/compiler/xla/service/gpu/openxla/", +# ), +# ) +# +# cc_library( +# name = "executable", +# srcs = ["executable.cc"], +# hdrs = ["executable.h"], +# # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not +# # compatible with `non_prod` constraint. +# compatible_with = [], +# deps = [ +# ":compiler", +# "@com_google_absl//absl/log", +# "@com_google_absl//absl/log:check", +# "@com_google_absl//absl/strings", +# "//third_party/iree/runtime/src/iree/base", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/hal/drivers/cuda", +# "//third_party/iree/runtime/src/iree/modules/hal", +# "//third_party/iree/runtime/src/iree/modules/hal:types", +# "//third_party/iree/runtime/src/iree/vm", +# "//third_party/iree/runtime/src/iree/vm/bytecode:module", +# "@llvm-project//mlir:IR", +# "//tensorflow/compiler/xla:status", +# "//tensorflow/compiler/xla:statusor", +# "//tensorflow/compiler/xla:util", +# "//tensorflow/compiler/xla:xla_proto_cc", +# "//tensorflow/compiler/xla/service:buffer_assignment", +# "//tensorflow/compiler/xla/service:executable", +# "//tensorflow/compiler/xla/service/gpu:buffer_allocations", +# "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", +# ], +# ) +# +# copybara:uncomment_end_and_comment_begin +cc_library( + name = "executable", + hdrs = ["executable.h"], + defines = ["XLA_DISABLE_OPENXLA_RUNTIME=1"], + deps = [ + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service/gpu:buffer_allocations", + "@llvm-project//mlir:IR", + ], +) +# copybara:comment_end diff --git a/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc new file mode 100644 index 00000000000000..74df74631a9a9b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc @@ -0,0 +1,238 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/compiler.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "third_party/iree/compiler/bindings/c/iree/compiler/embedding_api.h" +#include "third_party/iree/compiler/bindings/c/iree/compiler/loader.h" +#include "third_party/iree/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/compiler/xla/service/gpu/openxla/google/compiler.h" +#else +#include "tensorflow/compiler/xla/service/gpu/openxla/default/compiler.h" +#endif + +namespace xla::gpu { + +//===-----------------------------------------------------------------------===/ +// OpenXlaCompiler::Bytecode +//===-----------------------------------------------------------------------===/ + +OpenXlaCompiler::Bytecode::Bytecode(iree_compiler_output_t* output, void* data, + size_t length) + : output_(output), data_(data), length_(length) {} + +OpenXlaCompiler::Bytecode::~Bytecode() { ireeCompilerOutputDestroy(output_); } + +//===-----------------------------------------------------------------------===/ +// OpenXlaCompiler +//===-----------------------------------------------------------------------===/ + +OpenXlaCompiler::OpenXlaCompiler(iree_compiler_session_t* session, + iree_compiler_invocation_t* inv) + : session_(session), inv_(inv) {} + +OpenXlaCompiler::~OpenXlaCompiler() { + if (error_) { + ireeCompilerErrorDestroy(error_); + } + + ireeCompilerInvocationDestroy(inv_); + ireeCompilerSessionDestroy(session_); + + if (output_) { + ireeCompilerOutputDestroy(output_); + } +} + +bool OpenXlaCompiler::ParseSourceBuffer(std::string_view buffer) { + iree_compiler_source_t* source; + auto* error = ireeCompilerSourceWrapBuffer( + session_, "", buffer.data(), buffer.size(), + /*isNullTerminated=*/false, &source); + if (error) { + SetError(error); + return false; + } + + return ireeCompilerInvocationParseSource(inv_, source); +} + +bool OpenXlaCompiler::SetFlag(const char* flag) { + auto* error = ireeCompilerSessionSetFlags(session_, 1, &flag); + if (error) { + SetError(error); + return false; + } + return true; +} + +std::unique_ptr +OpenXlaCompiler::CompileStandardPipeline() { + if (!ireeCompilerInvocationPipeline(inv_, IREE_COMPILER_PIPELINE_STD)) { + return nullptr; + } + + iree_compiler_error_t* error = ireeCompilerOutputOpenMembuffer(&output_); + if (error) { + SetError(error); + return nullptr; + } + + error = ireeCompilerInvocationOutputVMBytecode(inv_, output_); + if (error) { + SetError(error); + return nullptr; + } + + void* output_data = nullptr; + uint64_t size; + error = ireeCompilerOutputMapMemory(output_, &output_data, &size); + if (error) { + SetError(error); + return nullptr; + } + + // Transfer the output_ to Bytecode since the mapping is only + // valid for the life of the 
output. + iree_compiler_output_t* local_output = output_; + output_ = nullptr; + return std::make_unique(local_output, output_data, size); +} + +//===-----------------------------------------------------------------------===/ +// Loading OpenXlaCompiler from a library +//===-----------------------------------------------------------------------===/ + +static bool InitializeCompilerForProcess(const std::string& library_path) { + if (!ireeCompilerLoadLibrary(library_path.c_str())) { + return false; + } + + ireeCompilerGlobalInitialize(); + return true; +} + +static std::optional LoadCompilerStubOnce( + const std::string& library_path) { + static std::string* loaded_path = nullptr; + + static absl::once_flag loaded; + absl::call_once(loaded, [&] { + if (InitializeCompilerForProcess(library_path)) { + loaded_path = new std::string(library_path); + } + }); + + if (loaded_path) return *loaded_path; + return std::nullopt; +} + +std::unique_ptr CreateOpenXlaCompiler() { + LoadCompilerStubOnce(GetIREECompilerPath()); + + auto* session = ireeCompilerSessionCreate(); + auto* inv = ireeCompilerInvocationCreate(session); + + ireeCompilerInvocationEnableConsoleDiagnostics(inv); + + return std::make_unique(session, inv); +} + +//===-----------------------------------------------------------------------===/ +// Binding Xla device kernels to an input module. +//===-----------------------------------------------------------------------===/ + +using namespace mlir; // NOLINT +using namespace mlir::iree_compiler; // NOLINT + +// TODO(ezhulenev): Query compute capability from the XLA module and set it up +// at the module level. +static constexpr int kComputeCapability = 60; + +static IREE::Input::ExecutableTargetAttr getExecutableTarget(MLIRContext* ctx) { + Builder b(ctx); + + SmallVector config{ + {b.getStringAttr("target_arch"), + b.getStringAttr(llvm::formatv("sm_{0}", kComputeCapability))}, + }; + + return IREE::Input::ExecutableTargetAttr::get( + ctx, b.getStringAttr("cuda"), b.getStringAttr("cuda-nvptx-fb"), + b.getDictionaryAttr(config)); +} + +static IREE::Input::ExecutableObjectAttr getExecutableObject( + MLIRContext* ctx, const std::vector& binary) { + Builder b(ctx); + + // TODO(ezhulenev): Use dense i8 arrays to pass binary data. + auto vec = VectorType::get(binary.size(), b.getI8Type()); + return IREE::Input::ExecutableObjectAttr::get( + ctx, /*path=*/b.getStringAttr(""), + DenseIntElementsAttr::get(vec, binary)); +} + +static IREE::Input::ExecutableObjectsAttr getExecutableObjects( + IREE::Input::ExecutableTargetAttr target, + IREE::Input::ExecutableObjectAttr executable) { + Builder b(target.getContext()); + return IREE::Input::ExecutableObjectsAttr::get( + b.getContext(), b.getArrayAttr(target), + b.getArrayAttr(b.getArrayAttr(executable))); +} + +Status BindXlaDeviceKernels(mlir::ModuleOp module, std::string_view asm_text, + const std::vector& binary) { + auto* ctx = module.getContext(); + SymbolTable sym_table(module); + + auto src = + sym_table.lookup("xla.module.ptx"); + if (!src) return InternalError("failed to find XLA executable source"); + + // Bind XLA device kernels to an executable source. 
+ auto objects = getExecutableObjects(getExecutableTarget(ctx), + getExecutableObject(ctx, binary)); + src.setObjectsAttr(objects); + + return OkStatus(); +} + +} // namespace xla::gpu diff --git a/tensorflow/compiler/xla/service/gpu/openxla/compiler.h b/tensorflow/compiler/xla/service/gpu/openxla/compiler.h new file mode 100644 index 00000000000000..fa67fce0b1544b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/compiler.h @@ -0,0 +1,95 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_COMPILER_H_ + +#include + +#include +#include +#include + +#include "third_party/iree/compiler/bindings/c/iree/compiler/embedding_api.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/xla/status.h" + +namespace xla::gpu { + +// Forward declare. +class OpenXlaCompiler; + +// Returns a new instance of the OpenXLA compiler loading it from a library. +// Every instance of the compiler creates a unique IREE compiler session. +std::unique_ptr CreateOpenXlaCompiler(); + +// Updates OpenXLA input module with device kernels compiled by XLA. +Status BindXlaDeviceKernels(mlir::ModuleOp, std::string_view asm_text, + const std::vector& binary); + +// Wrapper around IREE compiler + bundled OpenXLA compiler plugins to +// orchestrate compilation from OpenXLA input dialects for IREE VM +// flatbuffer. +// +// TODO(ezhulenev): Instead of returning `bool` return helpful Status +// errors. +class OpenXlaCompiler { + public: + // RAII wrapper around the compiler output (IREE VM bytecode). 
+ class Bytecode { + public: + Bytecode(iree_compiler_output_t* output, void* data, size_t length); + ~Bytecode(); + + void* data() { return data_; } + size_t lenth() { return length_; } + + private: + iree_compiler_output_t* output_; + void* data_; + size_t length_; + }; + + OpenXlaCompiler(iree_compiler_session_t* session, + iree_compiler_invocation_t* inv); + + ~OpenXlaCompiler(); + + bool ParseSourceBuffer(std::string_view buffer); + + bool SetFlag(const char* flag); + + std::unique_ptr CompileStandardPipeline(); + + private: + void SetError(iree_compiler_error_t* error) { + LOG(ERROR) << "OpenXLA compiler error: " + << ireeCompilerErrorGetMessage(error); + if (error_) { + ireeCompilerErrorDestroy(error_); + } + error_ = error; + } + + iree_compiler_session_t* session_ = nullptr; + iree_compiler_invocation_t* inv_ = nullptr; + + iree_compiler_error_t* error_ = nullptr; + iree_compiler_output_t* output_ = nullptr; +}; + +} // namespace xla::gpu + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/default/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/default/BUILD new file mode 100644 index 00000000000000..0270f0b8d93443 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/default/BUILD @@ -0,0 +1,17 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = ["//tensorflow/compiler/xla:friends"], +) + +cc_library( + name = "compiler", + hdrs = ["compiler.h"], +) diff --git a/tensorflow/compiler/xla/service/gpu/openxla/default/compiler.h b/tensorflow/compiler/xla/service/gpu/openxla/default/compiler.h new file mode 100644 index 00000000000000..496dedfc36f3df --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/default/compiler.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_DEFAULT_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_DEFAULT_COMPILER_H_ + +#include +#include + +namespace xla::gpu { + +inline std::string GetIREECompilerPath() { + if (const char* path = std::getenv("XLA_OPENXLA_IREE_COMPILER_LIB")) { + return std::string(path); + } + return "libIREECompiler.so"; +} + +} // namespace xla::gpu + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_DEFAULT_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.cc b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc new file mode 100644 index 00000000000000..f96a63c347c13a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc @@ -0,0 +1,253 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/executable.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "third_party/iree/runtime/src/iree/base/api.h" +#include "third_party/iree/runtime/src/iree/hal/api.h" +#include "third_party/iree/runtime/src/iree/hal/drivers/cuda/api.h" +#include "third_party/iree/runtime/src/iree/modules/hal/module.h" +#include "third_party/iree/runtime/src/iree/modules/hal/types.h" +#include "third_party/iree/runtime/src/iree/vm/api.h" +#include "third_party/iree/runtime/src/iree/vm/bytecode/module.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/compiler.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla::gpu { + +// TODO(ezhulenev): In this file we need to remove all IREE_CHECK_OK and replace +// with RETURN_IF_ERROR macro that will do iree_status_t => Status conversion. 
+ +//===-----------------------------------------------------------------------===/ +// OpenXlaDevice +//===-----------------------------------------------------------------------===/ + +struct OpenXlaDevice { + ~OpenXlaDevice() { + iree_hal_device_destroy(device); + iree_hal_driver_destroy(driver); + } + + iree_hal_driver_t* driver = nullptr; + iree_hal_device_t* device = nullptr; +}; + +static iree_status_t CreateCudaDriver(iree_allocator_t allocator, + OpenXlaDevice* device) { + iree_string_view_t driver_name = iree_make_cstring_view("cuda"); + + iree_hal_cuda_device_params_t default_params; + iree_hal_cuda_device_params_initialize(&default_params); + default_params.command_buffer_mode = IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM; + default_params.allow_inline_execution = false; + + iree_hal_cuda_driver_options_t driver_options; + iree_hal_cuda_driver_options_initialize(&driver_options); + driver_options.default_device_index = 0; + + IREE_CHECK_OK(iree_hal_cuda_driver_create(driver_name, &default_params, + &driver_options, allocator, + &device->driver)); + + return iree_ok_status(); +} + +static iree_status_t CreateCudaDevice(iree_allocator_t allocator, + OpenXlaDevice* device) { + IREE_CHECK_OK(iree_hal_driver_create_device_by_id( + device->driver, /*device_id=*/0, + /*param_count=*/0, /*params=*/nullptr, allocator, &device->device)); + return iree_ok_status(); +} + +//===-----------------------------------------------------------------------===/ +// OpenXlaRuntimeExecutable +//===-----------------------------------------------------------------------===/ + +// TODO(ezhulenev): Crashing is absolutely not ok here, add proper error +// handling and remove CHECK and IREE_CHECK_OK. + +/*static*/ StatusOr> +OpenXlaRuntimeExecutable::Create(std::unique_ptr program, + std::string_view asm_text, + const std::vector& binary) { + CHECK_OK(BindXlaDeviceKernels(*program->module, asm_text, binary)); + auto source = llvm_ir::DumpToString(*program->module); + + // Compile IR in OpenXLA input dialect(s) into IREE VM flatbuffer. + auto compiler = CreateOpenXlaCompiler(); + CHECK(compiler->SetFlag("--iree-hal-target-backends=cuda")); + CHECK(compiler->ParseSourceBuffer(source)); + auto bytecode = compiler->CompileStandardPipeline(); + CHECK(bytecode); + + // TODO(ezhulenev): We need a better strategy for managing IREE resources: VM + // instances, contexts, devices, etc. What should be shared between all XLA + // executables, and what should be unique to each executable? + iree_allocator_t allocator = iree_allocator_system(); + + // Create the root isolated VM instance that we can create contexts within. + iree::vm::ref instance; + IREE_CHECK_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, + allocator, &instance)); + + // TODO(ezhulenev): CUDA devices/drivers should be created globally, and share + // CUDA context with corresponding StreamExecutor. + auto device = std::make_unique(); + IREE_CHECK_OK(CreateCudaDriver(allocator, device.get())); + IREE_CHECK_OK(CreateCudaDevice(allocator, device.get())); + + auto modules = std::make_unique>(); + + // Load HAL module. + IREE_CHECK_OK(iree_hal_module_register_all_types(instance.get())); + IREE_CHECK_OK(iree_hal_module_create(instance.get(), device->device, + IREE_HAL_MODULE_FLAG_NONE, allocator, + &modules->emplace_back())); + + // Load module compiled from XLA program to a VM flatbuffer. 
+ IREE_CHECK_OK(iree_vm_bytecode_module_create( + instance.get(), + iree_make_const_byte_span(bytecode->data(), bytecode->lenth()), + /*archive_allocator=*/iree_allocator_null(), allocator, + &modules->emplace_back())); + + // TODO(ezhulenev): Figure out what is the correct context management strategy + // given that executable can be executed concurrently. This is almost + // certainly wrong, and will lead to data races. + iree::vm::ref context; + IREE_CHECK_OK(iree_vm_context_create_with_modules( + instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules->size(), + modules->data(), allocator, &context)); + + // Fully qualified entry point name. + auto module_name = iree_vm_module_name(modules->back()); + std::string qualified_name = std::string(module_name.data, module_name.size) + + "." + program->entry_point; + + // Look up the function by fully-qualified name (module.func). + iree_vm_function_t function; + IREE_CHECK_OK(iree_vm_context_resolve_function( + context.get(), iree_make_cstring_view(qualified_name.c_str()), + &function)); + + return std::unique_ptr(new OpenXlaRuntimeExecutable( + std::move(device), std::move(bytecode), std::move(program->buffer_sizes), + std::move(program->debug_options), context, instance, std::move(modules), + function)); +} + +OpenXlaRuntimeExecutable::OpenXlaRuntimeExecutable( + std::unique_ptr device, std::unique_ptr bytecode, + std::vector buffer_sizes, DebugOptions debug_options, + iree::vm::ref context, + iree::vm::ref instance, + std::unique_ptr> modules, + iree_vm_function_t function) + : device_(std::move(device)), + bytecode_(std::move(bytecode)), + buffer_sizes_(std::move(buffer_sizes)), + debug_options_(std::move(debug_options)), + context_(std::move(context)), + instance_(std::move(instance)), + modules_(std::move(modules)), + function_(std::move(function)) { + auto name = iree_vm_function_name(&function_); + VLOG(1) << "Created OpenXLA executable: function name = " + << std::string_view(name.data, name.size); +} + +OpenXlaRuntimeExecutable::~OpenXlaRuntimeExecutable() { + for (auto module : *modules_) iree_vm_module_release(module); +} + +Status OpenXlaRuntimeExecutable::Execute( + const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, + const BufferAllocation* temp_alloc) { + unsigned num_buffer_allocations = buffer_allocations.size(); + CHECK(num_buffer_allocations == buffer_sizes_.size()); // CHECK OK + + iree_allocator_t allocator = iree_allocator_system(); + + // Convert XLA buffer allocations to IREE buffer views. + iree::vm::ref inputs; + IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), + buffer_allocations.size(), allocator, + &inputs)); + + // Import argument buffers as device-local IREE buffers. + std::vector> buffers; + + for (unsigned i = 0; i < num_buffer_allocations; ++i) { + // Import XLA buffer as an IREE external buffer. + iree_hal_external_buffer_t external_buffer; + external_buffer.type = IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION; + external_buffer.flags = IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE; + external_buffer.size = buffer_sizes_[i]; + external_buffer.handle.device_allocation.ptr = reinterpret_cast( + buffer_allocations.GetDeviceAddress(i).opaque()); + + // All XLA:GPU buffer arguments are always allocated on device. 
+ iree_hal_buffer_params_t buffer_params = { + /*usage=*/IREE_HAL_BUFFER_USAGE_DEFAULT, + /*access=*/IREE_HAL_MEMORY_ACCESS_ALL, + /*type=*/IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }; + + // Import XLA buffer with no-op release callback, because the lifetime of + // arguments is managed by XLA itself and we get non-owning pointers. + IREE_CHECK_OK(iree_hal_allocator_import_buffer( + iree_hal_device_allocator(device_->device), buffer_params, + &external_buffer, {[](void*, iree_hal_buffer_t*) {}, nullptr}, + &buffers.emplace_back())); + + // In XLA all buffer arguments are vectors of i8 data type. + iree_hal_buffer_view_t* view = nullptr; + IREE_CHECK_OK(iree_hal_buffer_view_create( + buffers.back().get(), + /*shape_rank=*/1, + /*shape=*/&external_buffer.size, IREE_HAL_ELEMENT_TYPE_INT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, allocator, &view)); + + // Move buffer view to the inputs list. + iree_vm_ref_t view_ref = iree_hal_buffer_view_move_ref(view); + IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &view_ref)); + } + + IREE_CHECK_OK(iree_vm_invoke(context_.get(), function_, + IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/nullptr, inputs.get(), + /*outputs=*/nullptr, allocator)); + + return OkStatus(); +} + +} // namespace xla::gpu diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.h b/tensorflow/compiler/xla/service/gpu/openxla/executable.h new file mode 100644 index 00000000000000..33ad00401adfbe --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.h @@ -0,0 +1,167 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_EXECUTABLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_EXECUTABLE_H_ + +#include +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +//===----------------------------------------------------------------------===// +// TODO(ezhulenev): We currently do not build with OpenXLA runtime in open +// source because we do not have bazel dependency from XLA to IREE. 
+#if XLA_DISABLE_OPENXLA_RUNTIME +//===----------------------------------------------------------------------===// + +namespace xla::gpu { +struct OpenXlaRuntimeProgram { + OpenXlaRuntimeProgram(std::unique_ptr, + mlir::OwningOpRef, std::string, + std::vector, DebugOptions) {} +}; + +struct OpenXlaRuntimeExecutable { + static StatusOr> Create( + std::unique_ptr, std::string_view, + const std::vector&) { + return absl::UnimplementedError( + "OpenXLA runtime is not supported in OSS build"); + } + + Status Execute(const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, + const BufferAllocation* temp_alloc) { + return absl::UnimplementedError( + "OpenXLA runtime is not supported in OSS build"); + } +}; + +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +#else // !XLA_DISABLE_OPENXLA_RUNTIME +//===----------------------------------------------------------------------===// + +#include +#include + +#include "third_party/iree/runtime/src/iree/vm/api.h" +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/compiler.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { +namespace gpu { + +// Forward declare. +struct OpenXlaDevice; + +// Xla Gpu program lowered to the OpenXLA dialects (IREE input dialects). +// OpenXLA runtime executable jit-compiles this program to an executable +// artifact (via lowering to IREE VM executable). +// +// Top level module has a single HAL executable source that contains all device +// kernels for an XLA module. After lowering from LMHLO executable source is +// just a placeholder; it gets updated with a real device kernels source only +// before we pass IR to OpenXLA/IREE compiler (see OpenXlaRuntimeExecutable +// below), because in the XLA compilation pipeline backend compiler runs last. +// +// We have this program as an intermediate step between lowering from LMHLO to +// VM executable to be able to introspect the compilation process. Once we have +// this program, the Xla gpu compiler job is done, and lowering to IREE VM is +// the responsibility of OpenXLA/IREE compiler. +struct OpenXlaRuntimeProgram { + OpenXlaRuntimeProgram(std::unique_ptr ctx, + mlir::OwningOpRef module, + std::string entry_point, + std::vector buffer_sizes, + DebugOptions debug_options) + : ctx(std::move(ctx)), + module(std::move(module)), + entry_point(std::move(entry_point)), + buffer_sizes(std::move(buffer_sizes)), + debug_options(std::move(debug_options)) {} + + std::unique_ptr ctx; + mlir::OwningOpRef module; + std::string entry_point; + std::vector buffer_sizes; + DebugOptions debug_options; +}; + +// Gpu runtime executable encapsulates the Xla runtime executable compiled from +// an Xla program and owns all the state required for running it (e.g. it owns +// various caches required for performance). +class OpenXlaRuntimeExecutable { + public: + using Bytecode = OpenXlaCompiler::Bytecode; + + // Creates OpenXlaRuntimeExecutable from the OpenXLA program. + static StatusOr> Create( + std::unique_ptr program, std::string_view asm_text, + const std::vector& binary); + + ~OpenXlaRuntimeExecutable(); + + // Executes entry function with the given buffer arguments. 
+ Status Execute(const ServiceExecutableRunOptions* run_options, + const BufferAllocations& buffer_allocations, + const BufferAllocation* temp_alloc); + + private: + OpenXlaRuntimeExecutable( + std::unique_ptr device, std::unique_ptr bytecode, + std::vector buffer_sizes, DebugOptions debug_options, + iree::vm::ref context, + iree::vm::ref instance, + std::unique_ptr> modules, + iree_vm_function_t function); + + // TODO(ezhulenev): Devices should be created lazily for each StreamExecutor + // and share underlying resources. For now we create a CUDA driver and CUDA + // HAL device for each executable. And we assume that we have just once GPU + // attached to the host, and always run on device with ordinal 0. + std::unique_ptr device_; + + std::unique_ptr bytecode_; + + std::vector buffer_sizes_; + const DebugOptions debug_options_; + + // TODO(ezhulenev): VM context and instance should be shared between multiple + // executables. Also HAL module should be loaded just once. This has to be + // fixed together with efficient device sharing, because HAL VM module + // requires HAL device for loading. + iree::vm::ref context_; + iree::vm::ref instance_; + std::unique_ptr> modules_; + iree_vm_function_t function_; +}; + +} // namespace gpu +} // namespace xla + +#endif // !XLA_DISABLE_OPENXLA_RUNTIME +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 3a92c78a5059e8..5ce1f65330c66c 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -422,6 +422,15 @@ message DebugOptions { // If true, use XLA runtime for XLA:GPU backend. bool xla_gpu_enable_xla_runtime_executable = 169; + // If true, use OpenXLA runtime for XLA:GPU backend. That is, use IREE VM + // as a host executable, CUDA HAL for dispatching device kernels and + // custom modules for integration with libraries required for running + // XLA:GPU programs. + // + // Note: this mode disables thunks and the "classic" gpu runtime, which + // is defined above. + bool xla_gpu_enable_openxla_runtime = 233; + // Timeout in seconds before terminating jobs that are stuck in a NCCL // Rendezvous. Negative value disables the timeout and will not terminate. int64 xla_gpu_nccl_termination_timeout_seconds = 163; @@ -580,7 +589,7 @@ message DebugOptions { bool xla_gpu_dump_autotuned_triton_fusions = 232; - // Next id: 233 + // Next id: 234 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. From 54c35481b516b846b50bd1a000a3150684eb2bce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 13:30:22 -0700 Subject: [PATCH 063/410] Implementation of FFT SPMD expander in DTensor. 
PiperOrigin-RevId: 550657034 --- tensorflow/dtensor/mlir/BUILD | 1 + tensorflow/dtensor/mlir/collectives.cc | 24 + tensorflow/dtensor/mlir/collectives.h | 6 + .../mlir/expansions/fft_spmd_expander.cc | 512 ++++++++ .../mlir/expansions/fft_spmd_expander.h | 44 + tensorflow/dtensor/mlir/spmd_expanders.cc | 49 +- tensorflow/dtensor/mlir/utils/BUILD | 1 + .../dtensor/mlir/utils/collective_lowering.cc | 26 +- tensorflow/dtensor/python/tests/BUILD | 2 + tensorflow/dtensor/python/tests/spmd_test.py | 1110 +++++++++++++++++ 10 files changed, 1714 insertions(+), 61 deletions(-) create mode 100644 tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc create mode 100644 tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index e2b6ac49bb1556..e543e419cde8a5 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -482,6 +482,7 @@ cc_library( ":spmd_expander_common", ":tf_dtensor_dialect", ":value_utils", + "//tensorflow/cc:ops", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", diff --git a/tensorflow/dtensor/mlir/collectives.cc b/tensorflow/dtensor/mlir/collectives.cc index a13f47664bd186..0dfb9fd820b89d 100644 --- a/tensorflow/dtensor/mlir/collectives.cc +++ b/tensorflow/dtensor/mlir/collectives.cc @@ -384,6 +384,30 @@ StatusOr EmitRelayout( newly_created_ops); } +mlir::Operation* EmitTransposeOp(mlir::OpBuilder& builder, + const mlir::Location& loc, mlir::Value input, + std::vector& perm_arr) { + auto tr_input_type = input.getType().cast(); + auto shape = tr_input_type.getShape(); + + auto perm_type = mlir::RankedTensorType::get( + {static_cast(perm_arr.size())}, builder.getIntegerType(64)); + + auto constant_attr = builder.getI64TensorAttr(perm_arr); + auto perm_op = + builder.create(loc, perm_type, constant_attr); + + std::vector transposed_shape(shape.begin(), shape.end()); + for (int i = 0; i < shape.size(); i++) { + transposed_shape[i] = shape[perm_arr[i]]; + } + auto transposed_type = mlir::RankedTensorType::get( + transposed_shape, tr_input_type.getElementType()); + + return builder.create(loc, transposed_type, input, + perm_op); +} + StatusOr EmitBarrierWithConstValue(mlir::OpBuilder& builder, mlir::Location loc, const Mesh& mesh, diff --git a/tensorflow/dtensor/mlir/collectives.h b/tensorflow/dtensor/mlir/collectives.h index 09aee28572caea..8cea0f3feab278 100644 --- a/tensorflow/dtensor/mlir/collectives.h +++ b/tensorflow/dtensor/mlir/collectives.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_DTENSOR_MLIR_COLLECTIVES_H_ #include +#include #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" @@ -53,6 +54,11 @@ StatusOr EmitRelayout( const dtensor::Layout& tgt_layout, llvm::SmallPtrSet* newly_created_ops = nullptr); +// Emits TransposeOp that permutes the input shape. +mlir::Operation* EmitTransposeOp(mlir::OpBuilder& builder, + const mlir::Location& loc, mlir::Value input, + std::vector& perm_arr); + // Emits collective ops to reduce `input` over `reduced_dims`. 
StatusOr EmitAllReduce( mlir::OpBuilder& builder, const dtensor::Layout& output_layout, diff --git a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc new file mode 100644 index 00000000000000..69d07dfb06a4e8 --- /dev/null +++ b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc @@ -0,0 +1,512 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/dtensor/cc/dstatus.h" +#include "tensorflow/dtensor/cc/tensor_layout.h" +#include "tensorflow/dtensor/mlir/collectives.h" +#include "tensorflow/dtensor/mlir/expansions/meta_spmd_expander.h" +#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" +#include "tensorflow/dtensor/mlir/layout_parsing.h" +#include "tensorflow/dtensor/mlir/op_utils.h" +#include "tensorflow/dtensor/mlir/shape_utils.h" +#include "tensorflow/dtensor/mlir/spmd_expander.h" +#include "tensorflow/dtensor/mlir/spmd_expander_common.h" +#include "tensorflow/dtensor/mlir/value_utils.h" + +namespace tensorflow { +namespace dtensor { +namespace { + +// Return the last unsharded axis. Return -1 for fully replicated dtensor +int LastUnshardedAxis(const std::vector& sharding_specs) { + for (int i = sharding_specs.size() - 1; i >= 0; --i) + if (sharding_specs[i] == Layout::kUnshardedDim) return i; + return -1; +} + +// Return false for non-distributed *FFTN. 
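+// Equivalently: return true iff any of the last `num_transform_axes`
+// dimensions of `layout` is sharded.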
+bool IsDistributedFFTN(int num_transform_axes, const Layout& layout) { + std::vector sharding_specs = layout.sharding_spec_strs(); + for (int i = sharding_specs.size() - num_transform_axes; + i < sharding_specs.size(); ++i) + if (sharding_specs[i] != Layout::kUnshardedDim) { + return true; + } + return false; +} + +bool IsComplexFFT(mlir::Value input) { + auto data_type = + mlir::dyn_cast(input.getType()).getElementType(); + return data_type.isa(); +} + +Status IsProperFFTLength(mlir::Operation* op, + const llvm::SmallVector& fft_length_vec) { + TF_ASSIGN_OR_RETURN(auto input_layout, + ExtractRequiredLayoutFromOperand(op->getOperand(0))); + const Mesh& mesh = input_layout.mesh(); + int axes = fft_length_vec.size(); + // RFFT in DTensor requires axes except -1 to have the same shape as input. + llvm::ArrayRef input_shape = + mlir::dyn_cast(op->getOperand(0).getType()).getShape(); + std::vector input_shape_vec = input_shape.vec(); + for (int i = 0; i < axes - 1; ++i) + if (fft_length_vec[i] != input_shape_vec[input_shape_vec.size() - axes + i]) + return absl::InvalidArgumentError( + "DTensor RFFTOps are not suitable for current 'fft_length'."); + + // fft_length[-1] should be divisible by the corresponding device number. + int num_of_devices_last_dim = mesh.dim_sizes()[input_shape_vec.size() - 1]; + if (axes > 1 && fft_length_vec[axes - 1] % num_of_devices_last_dim == 1) + return absl::InvalidArgumentError( + "The values with current 'fft_length' are not shardable."); + return absl::OkStatus(); +} + +StatusOr> ExtractFFTLengthFromOp( + mlir::Operation* op) { + mlir::Value fft_length = op->getOperand(1); + llvm::SmallVector fft_length_vec; + TF_RETURN_IF_ERROR(ExtractConstVectorFromValue(fft_length, &fft_length_vec)); + TF_RETURN_IF_ERROR(IsProperFFTLength(op, fft_length_vec)); + return fft_length_vec; +} + +// Forward flow for FFT and backward flow for iFFT +// sharding_specs has at least one element +void PropagateFFTLayout(std::vector& sharding_specs, int axes) { + int last_unsharded_axis = LastUnshardedAxis(sharding_specs); + if (last_unsharded_axis == -1) + sharding_specs[sharding_specs.size() - 1] = Layout::kUnshardedDim; + else if (last_unsharded_axis != sharding_specs.size() - 1) + std::iter_swap(sharding_specs.end() - 1, + sharding_specs.begin() + last_unsharded_axis); + std::string last_sharding_spec = sharding_specs.back(); + sharding_specs.pop_back(); + sharding_specs.insert(sharding_specs.end() - axes + 1, last_sharding_spec); +} + +// Backward flow for FFT and forward flow for iFFT +// sharding_specs has at least one element +void PropagateIFFTLayout(std::vector& sharding_specs, int axes) { + int last_unsharded_axis = LastUnshardedAxis(sharding_specs); + if (last_unsharded_axis == -1) + sharding_specs[sharding_specs.size() - axes] = Layout::kUnshardedDim; + else if (last_unsharded_axis != sharding_specs.size() - axes) + std::iter_swap(sharding_specs.end() - axes, + sharding_specs.begin() + last_unsharded_axis); + std::string unsharded_axis = sharding_specs[sharding_specs.size() - axes]; + sharding_specs.erase(sharding_specs.end() - axes); + sharding_specs.push_back(unsharded_axis); +} + +StatusOr EmitTransposeRelayout(mlir::OpBuilder& builder, + mlir::Location location, + mlir::Value input, + const Layout& init_layout, + const Mesh& mesh, + std::pair& perm_axes) { + std::vector perm_for_transpose; + int input_rank = ValueRank(input); + perm_for_transpose.reserve(input_rank); + for (int ax = 0; ax < input_rank; ++ax) { + perm_for_transpose.push_back(ax); + } + 
std::iter_swap(perm_for_transpose.begin() + perm_axes.first, + perm_for_transpose.begin() + perm_axes.second); + mlir::Operation* transpose_op = + EmitTransposeOp(builder, location, input, perm_for_transpose); + mlir::Value transposed_input = transpose_op->getResult(0); + + std::vector sharding_specs = init_layout.sharding_spec_strs(); + std::iter_swap(sharding_specs.begin() + perm_axes.first, + sharding_specs.begin() + perm_axes.second); + TF_ASSIGN_OR_RETURN( + auto transposed_input_layout, + Layout::GetLayout(init_layout.type(), sharding_specs, mesh)); + TF_ASSIGN_OR_RETURN( + transposed_input, + EmitRelayout(transposed_input, transposed_input_layout, init_layout)); + return transposed_input; +} + +Status NormalizeAxes(std::vector& transform_axes, int input_rank) { + std::sort(transform_axes.begin(), transform_axes.end()); + for (int i = 0; i < transform_axes.size(); ++i) { + if (transform_axes[i] >= input_rank) { + return absl::InvalidArgumentError("Axes to perform FFTN on are invalid."); + } else if (transform_axes[i] < 0) { + transform_axes[i] += input_rank; + } + } + if (transform_axes.empty()) { + transform_axes.reserve(input_rank); + for (int i = 0; i < input_rank; ++i) transform_axes.push_back(i); + } + return absl::OkStatus(); +} + +// Make the last axis of the layout unsharded by swapping with another +// axis or forcing the last axis to be unsharded if fully sharded. +StatusOr UnshardLastAxis(int input_rank, const Layout& layout, + const Mesh& mesh) { + if (input_rank < 1) + return absl::InvalidArgumentError("input_rank must be >= 1"); + std::vector input_sharding_specs = layout.sharding_spec_strs(); + int last_unsharded_axis = LastUnshardedAxis(input_sharding_specs); + if (last_unsharded_axis == -1) + input_sharding_specs[input_rank - 1] = Layout::kUnshardedDim; + else if (last_unsharded_axis != input_rank - 1) + std::iter_swap(input_sharding_specs.end() - 1, + input_sharding_specs.begin() + last_unsharded_axis); + return Layout::GetLayout(layout.type(), input_sharding_specs, mesh); +} + +// Lowering FFTN/RFFTN operation for the first N-1 axes to N-1 1-d FFTOps. 
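+// Each remaining transform axis is swapped with the (unsharded) last axis, a
+// 1-d FFT is applied locally, and a final transpose restores the expected
+// axis order.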
+StatusOr ExpandFFTNImpl( + mlir::Operation* xfft_op, mlir::Operation* fft_op, + std::vector& transform_axes, const int input_rank, + mlir::OpBuilder& builder, mlir::Location location, + const Layout& input_layout, const Mesh& mesh) { + SetSingleLayoutOnOp(xfft_op, input_layout); + mlir::Value output = xfft_op->getResult(0); + + if (transform_axes.empty()) { + fft_op->getResult(0).replaceAllUsesWith(xfft_op->getResult(0)); + fft_op->erase(); + return InferSPMDExpandedLocalShape(xfft_op); + } + + std::vector perm; + perm.reserve(input_rank); + for (int i = 0; i < input_rank - 1; ++i) { + perm.push_back(i); + } + perm.insert(perm.end() - transform_axes.size(), input_rank - 1); + + mlir::TF::FFTOp new_fft_op; + Layout intermediate_layout = input_layout; + int ax; + while (!transform_axes.empty()) { + ax = transform_axes.back(); + transform_axes.pop_back(); + + std::pair perm_axes = {ax, input_rank - 1}; + TF_ASSIGN_OR_RETURN( + mlir::Value transposed_input, + EmitTransposeRelayout(builder, location, output, intermediate_layout, + mesh, perm_axes)); + + new_fft_op = builder.create( + location, transposed_input.getType(), transposed_input); + SetSingleLayoutOnOp(new_fft_op, intermediate_layout); + output = new_fft_op.getOutput(); + } + + mlir::Operation* transpose_op = + EmitTransposeOp(builder, location, output, perm); + mlir::Value transposed_output = transpose_op->getResult(0); + + llvm::SmallPtrSet newly_created_ops; + builder.setInsertionPointAfter(new_fft_op); + TF_ASSIGN_OR_RETURN(auto final_output, + EmitRelayout(transposed_output, intermediate_layout, + intermediate_layout, &newly_created_ops)); + fft_op->getOpResult(0).replaceAllUsesExcept(final_output, newly_created_ops); + fft_op->erase(); + return InferSPMDExpandedLocalShape(final_output.getDefiningOp()); +} + +StatusOr ExpandFFTN(mlir::Operation* fft_op, + std::vector& transform_axes) { + mlir::OpBuilder builder(fft_op); + mlir::Value input = fft_op->getOperand(0); + TF_ASSIGN_OR_RETURN(auto input_layout, + ExtractRequiredLayoutFromOperand(input)); + const int input_rank = ValueRank(input); + const Mesh& mesh = input_layout.mesh(); + mlir::Location location = fft_op->getLoc(); + + TF_RETURN_IF_ERROR(NormalizeAxes(transform_axes, input_rank)); + int num_transform_axes = transform_axes.size(); + + if (!IsDistributedFFTN(num_transform_axes, input_layout)) + return InferSPMDExpandedLocalShape(fft_op); + + // FIXME(b/292286720): Since the last axis must be one of the transform_axes + // in current 1/2/3d transform ops, we don't need to find the last + // transofrm_axes and can just use -1. Need to be fixed by adding transpose op + // prior to this or unshard the last transform_axes. + TF_ASSIGN_OR_RETURN(Layout intermediate_layout, + UnshardLastAxis(input_rank, input_layout, mesh)); + TF_ASSIGN_OR_RETURN(mlir::Value intermediate, + EmitRelayout(input, input_layout, intermediate_layout)); + + if (IsComplexFFT(input)) { + // FFT for the last axis. 
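+    // The last transform axis is unsharded in `intermediate_layout`, so the
+    // 1-d FFT below runs locally on each shard.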
+ mlir::TF::FFTOp fft_output_op = builder.create( + location, intermediate.getType(), intermediate); + transform_axes.pop_back(); + return ExpandFFTNImpl(fft_output_op, fft_op, transform_axes, input_rank, + builder, location, intermediate_layout, mesh); + } else { + TF_ASSIGN_OR_RETURN(auto fft_length_vec, ExtractFFTLengthFromOp(fft_op)); + mlir::Value fft_length = IntConst( + builder, location, (int32)fft_length_vec[num_transform_axes - 1]); + llvm::ArrayRef rfft_shape = + mlir::dyn_cast(intermediate.getType()).getShape(); + std::vector rfft_shape_vec = rfft_shape.vec(); + int num_of_devices_last_dim = mesh.dim_sizes()[input_rank - 1]; + rfft_shape_vec[input_rank - 1] = + fft_length_vec[num_transform_axes - 1] / 2 + 1; + if (fft_length_vec.size() > 1 && + rfft_shape_vec[input_rank - 1] % num_of_devices_last_dim != 0) + return absl::InvalidArgumentError( + "No suitable algorithm in DTensor found for current 'fft_length'."); + + mlir::Type output_type = mlir::RankedTensorType::get( + rfft_shape_vec, + mlir::dyn_cast(fft_op->getResult(0).getType()) + .getElementType()); + // Real FFT for the last axis. + mlir::TF::RFFTOp rfft_output_op = builder.create( + location, output_type, intermediate, fft_length); + transform_axes.pop_back(); + return ExpandFFTNImpl(rfft_output_op, fft_op, transform_axes, input_rank, + builder, location, intermediate_layout, mesh); + } +} + +StatusOr ExpandIFFTN(mlir::Operation* ifft_op, + std::vector& transform_axes) { + mlir::OpBuilder builder(ifft_op); + mlir::Value input = ifft_op->getOperand(0); + TF_ASSIGN_OR_RETURN(auto input_layout, + ExtractRequiredLayoutFromOperand(input)); + const auto input_rank = ValueRank(input); + const Mesh& mesh = input_layout.mesh(); + mlir::Location location = ifft_op->getLoc(); + std::vector input_sharding_specs = + input_layout.sharding_spec_strs(); + + TF_RETURN_IF_ERROR(NormalizeAxes(transform_axes, input_rank)); + int num_transform_axes = transform_axes.size(); + + if (!IsDistributedFFTN(num_transform_axes, input_layout)) + return InferSPMDExpandedLocalShape(ifft_op); + + input_sharding_specs.push_back( + input_sharding_specs[input_rank - num_transform_axes]); + input_sharding_specs.erase(input_sharding_specs.begin() + input_rank - + num_transform_axes); + TF_ASSIGN_OR_RETURN( + input_layout, + Layout::GetLayout(input_layout.type(), input_sharding_specs, mesh)); + TF_ASSIGN_OR_RETURN(Layout intermediate_layout, + UnshardLastAxis(input_rank, input_layout, mesh)); + + std::vector perm; + perm.reserve(input_rank); + for (int i = 0; i < input_rank; ++i) + if (i != input_rank - num_transform_axes) { + perm.push_back(i); + } + perm.push_back(input_rank - num_transform_axes); + mlir::Operation* transpose_op = + EmitTransposeOp(builder, location, input, perm); + mlir::Value transposed_output = transpose_op->getResult(0); + TF_ASSIGN_OR_RETURN( + transposed_output, + EmitRelayout(transposed_output, input_layout, intermediate_layout)); + + mlir::TF::IFFTOp fft_new_op; + int ax; // current axis + while (transform_axes.size() > 1) { + ax = transform_axes[1] - 1; + transform_axes.erase(transform_axes.begin() + 1); + fft_new_op = builder.create( + location, transposed_output.getType(), transposed_output); + SetSingleLayoutOnOp(fft_new_op, intermediate_layout); + transposed_output = fft_new_op.getOutput(); + // Swap and relayout + std::pair perm_axes = {ax, input_rank - 1}; + TF_ASSIGN_OR_RETURN( + transposed_output, + EmitTransposeRelayout(builder, location, transposed_output, + intermediate_layout, mesh, perm_axes)); + } + + if 
(IsComplexFFT(ifft_op->getResult(0))) { + // IFFT for the last axis. + mlir::TF::IFFTOp ifft_output_op = builder.create( + location, transposed_output.getType(), transposed_output); + SetSingleLayoutOnOp(ifft_output_op, intermediate_layout); + builder.setInsertionPointAfter(ifft_output_op); + + ifft_op->getResult(0).replaceAllUsesWith(ifft_output_op); + ifft_op->erase(); + return InferSPMDExpandedLocalShape(ifft_output_op); + } else { + TF_ASSIGN_OR_RETURN(auto complex_fft_length_vec, + ExtractFFTLengthFromOp(ifft_op)); + mlir::Value ifft_length = + IntConst(builder, location, + (int32)complex_fft_length_vec[num_transform_axes - 1]); + // IRFFT for the last axis. + mlir::TF::IRFFTOp irfft_output_op = builder.create( + location, ifft_op->getResult(0).getType(), transposed_output, + ifft_length); + SetSingleLayoutOnOp(irfft_output_op, intermediate_layout); + builder.setInsertionPointAfter(irfft_output_op); + ifft_op->getResult(0).replaceAllUsesWith(irfft_output_op.getOutput()); + ifft_op->erase(); + return InferSPMDExpandedLocalShape(irfft_output_op); + } +} +} // namespace + +StatusOr FFTSPMDExpander::ExpandOp(mlir::Operation* op) { + std::vector last_axis{-1}; + std::vector last_2_axes{-2, -1}; + std::vector last_3_axes{-3, -2, -1}; + return llvm::TypeSwitch>(op) + // Forward prop ops. + .Case( + [&](auto op) { return ExpandFFTN(op, last_axis); }) + .Case( + [&](auto op) { return ExpandFFTN(op, last_2_axes); }) + .Case( + [&](auto op) { return ExpandFFTN(op, last_3_axes); }) + + // Backward prop ops. + .Case( + [&](auto op) { return ExpandIFFTN(op, last_axis); }) + .Case( + [&](auto op) { return ExpandIFFTN(op, last_2_axes); }) + .Case( + [&](auto op) { return ExpandIFFTN(op, last_3_axes); }) + .Default([&](auto op) { return InferSPMDExpandedLocalShape(op); }); +} + +StatusOr> FFTSPMDExpander::ComputeLayoutForward( + mlir::Operation* op, const llvm::DenseMap& input_layouts) { + if (input_layouts.find(0) == input_layouts.end()) + return llvm::DenseMap(); + + const Layout& input_layout = input_layouts.lookup(0); + std::vector sharding_specs = input_layout.sharding_spec_strs(); + if (sharding_specs.empty()) + return absl::FailedPreconditionError( + absl::StrCat(OpName(op), " has no sharding specs.")); + llvm::TypeSwitch(op) + .Case( + [&sharding_specs](auto op) { PropagateFFTLayout(sharding_specs, 1); }) + .Case( + [&sharding_specs](auto op) { PropagateFFTLayout(sharding_specs, 2); }) + .Case( + [&sharding_specs](auto op) { PropagateFFTLayout(sharding_specs, 3); }) + + .Case([&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 1); + }) + .Case( + [&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 2); + }) + .Case( + [&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 3); + }); + + TF_ASSIGN_OR_RETURN(auto result_layout, + Layout::GetLayout(input_layout.type(), sharding_specs, + input_layout.mesh())); + if (result_layout.rank() != input_layout.rank()) + return absl::FailedPreconditionError(absl::StrCat( + OpName(op), " derived output layout rank is ", result_layout.rank(), + " not ", input_layout.rank(), " as expected.")); + + return llvm::DenseMap({{0, result_layout}}); +} + +StatusOr> FFTSPMDExpander::ComputeLayoutBackward( + mlir::Operation* op, const llvm::DenseMap& output_layouts) { + if (output_layouts.find(0) == output_layouts.end()) + return llvm::DenseMap(); + + const Layout& output_layout = output_layouts.lookup(0); + std::vector sharding_specs = output_layout.sharding_spec_strs(); + if (sharding_specs.empty()) + return 
absl::FailedPreconditionError( + absl::StrCat(OpName(op), " has no sharding specs.")); + + llvm::TypeSwitch(op) + .Case([&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 1); + }) + .Case([&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 2); + }) + .Case([&sharding_specs](auto op) { + PropagateIFFTLayout(sharding_specs, 3); + }) + + .Case( + [&sharding_specs](auto op) { PropagateFFTLayout(sharding_specs, 1); }) + .Case( + [&sharding_specs](auto op) { PropagateFFTLayout(sharding_specs, 2); }) + .Case( + [&sharding_specs](auto op) { + PropagateFFTLayout(sharding_specs, 3); + }); + + TF_ASSIGN_OR_RETURN(auto result_layout, + Layout::GetLayout(output_layout.type(), sharding_specs, + output_layout.mesh())); + if (result_layout.rank() != output_layout.rank()) + return absl::FailedPreconditionError(absl::StrCat( + OpName(op), " derived output layout rank is ", result_layout.rank(), + " not ", output_layout.rank(), " as expected.")); + + return llvm::DenseMap({{0, result_layout}}); +} + +} // namespace dtensor +} // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h new file mode 100644 index 00000000000000..8c0bd491d7b0d9 --- /dev/null +++ b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_FFT_SPMD_EXPANDER_H_ +#define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_FFT_SPMD_EXPANDER_H_ + +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/dtensor/cc/dstatus.h" +#include "tensorflow/dtensor/mlir/spmd_expander.h" + +namespace tensorflow { +namespace dtensor { + +// Implement Layout propagation and SPMD expansion for FFT ops. +class FFTSPMDExpander : public SPMDExpanderBase { + public: + StatusOr ExpandOp(mlir::Operation* op) override; + + StatusOr> ComputeLayoutForward( + mlir::Operation* op, + const llvm::DenseMap& input_layouts) override; + + StatusOr> ComputeLayoutBackward( + mlir::Operation* op, + const llvm::DenseMap& output_layouts) override; +}; + +} // namespace dtensor +} // namespace tensorflow +#endif // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_FFT_SPMD_EXPANDER_H_ diff --git a/tensorflow/dtensor/mlir/spmd_expanders.cc b/tensorflow/dtensor/mlir/spmd_expanders.cc index 9dd0fdad513e14..8c772e42a77e53 100644 --- a/tensorflow/dtensor/mlir/spmd_expanders.cc +++ b/tensorflow/dtensor/mlir/spmd_expanders.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h" #include "tensorflow/dtensor/mlir/expansions/elementwise_spmd_expander.h" #include "tensorflow/dtensor/mlir/expansions/expanddims_spmd_expander.h" +#include "tensorflow/dtensor/mlir/expansions/fft_spmd_expander.h" #include "tensorflow/dtensor/mlir/expansions/fill_spmd_expander.h" #include "tensorflow/dtensor/mlir/expansions/gather_spmd_expander.h" #include "tensorflow/dtensor/mlir/expansions/identity_n_spmd_expander.h" @@ -448,42 +449,18 @@ REGISTER_SPMD(AdjustContrastv2, TF::AdjustContrastv2Op, REGISTER_SPMD(AdjustSaturation, TF::AdjustSaturationOp, DataparallelSPMDExpander, llvm::DenseMap{{0, 3}}, llvm::DenseMap{{0, 3}}); -REGISTER_SPMD(FFT, TF::FFTOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(FFT2D, TF::FFT2DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(FFT3D, TF::FFT3DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IFFT, TF::IFFTOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IFFT2D, TF::IFFT2DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IFFT3D, TF::IFFT3DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IRFFT, TF::IRFFTOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IRFFT2D, TF::IRFFT2DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(IRFFT3D, TF::IRFFT3DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(RFFT, TF::RFFTOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(RFFT2D, TF::RFFT2DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); -REGISTER_SPMD(RFFT3D, TF::RFFT3DOp, DataparallelSPMDExpander, - llvm::DenseMap{{0, 1}}, - llvm::DenseMap{{0, 1}}); +REGISTER_SPMD(FFT, TF::FFTOp, FFTSPMDExpander); +REGISTER_SPMD(FFT2D, TF::FFT2DOp, FFTSPMDExpander); +REGISTER_SPMD(FFT3D, TF::FFT3DOp, FFTSPMDExpander); +REGISTER_SPMD(IFFT, TF::IFFTOp, FFTSPMDExpander); +REGISTER_SPMD(IFFT2D, TF::IFFT2DOp, FFTSPMDExpander); +REGISTER_SPMD(IFFT3D, TF::IFFT3DOp, FFTSPMDExpander); +REGISTER_SPMD(IRFFT, TF::IRFFTOp, FFTSPMDExpander); +REGISTER_SPMD(IRFFT2D, TF::IRFFT2DOp, FFTSPMDExpander); +REGISTER_SPMD(IRFFT3D, TF::IRFFT3DOp, FFTSPMDExpander); +REGISTER_SPMD(RFFT, TF::RFFTOp, FFTSPMDExpander); +REGISTER_SPMD(RFFT2D, TF::RFFT2DOp, FFTSPMDExpander); +REGISTER_SPMD(RFFT3D, TF::RFFT3DOp, FFTSPMDExpander); REGISTER_SPMD(Cholesky, TF::CholeskyOp, DataparallelSPMDExpander, llvm::DenseMap{{0, 2}}, llvm::DenseMap{{0, 2}}); diff --git a/tensorflow/dtensor/mlir/utils/BUILD b/tensorflow/dtensor/mlir/utils/BUILD index 48af934cffa8f6..0d9a5a39429fe6 100644 --- a/tensorflow/dtensor/mlir/utils/BUILD +++ b/tensorflow/dtensor/mlir/utils/BUILD @@ -36,6 +36,7 @@ cc_library( "//tensorflow/dtensor/cc:dtensor_utils", "//tensorflow/dtensor/cc:layout_to_xla_sharding", "//tensorflow/dtensor/cc:tensor_layout", + "//tensorflow/dtensor/mlir:collectives", "//tensorflow/dtensor/mlir:collectives_common", "//tensorflow/dtensor/mlir:create_dtensor_mlir_passes", "//tensorflow/dtensor/mlir:device_utils", diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 
5c03db58e8e9e3..753fe04e7898c5 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/dtensor/cc/dstatus.h" #include "tensorflow/dtensor/cc/dtensor_utils.h" #include "tensorflow/dtensor/cc/tensor_layout.h" +#include "tensorflow/dtensor/mlir/collectives.h" #include "tensorflow/dtensor/mlir/collectives_common.h" #include "tensorflow/dtensor/mlir/device_utils.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" @@ -272,31 +273,6 @@ mlir::Operation* EmitCollectiveReduce( return collective_reduce; } -// Emits TransposeOp with permuting passed dim_idx with first axis. -mlir::Operation* EmitTransposeOp(mlir::OpBuilder& builder, - const mlir::Location& loc, mlir::Value input, - std::vector perm_arr) { - auto tr_input_type = input.getType().cast(); - auto shape = tr_input_type.getShape(); - - auto perm_type = mlir::RankedTensorType::get( - {static_cast(perm_arr.size())}, builder.getIntegerType(64)); - - auto constant_attr = builder.getI64TensorAttr(perm_arr); - auto perm_op = - builder.create(loc, perm_type, constant_attr); - - std::vector transposed_shape(shape.begin(), shape.end()); - for (int i = 0; i < shape.size(); i++) { - transposed_shape[i] = shape[perm_arr[i]]; - } - auto transposed_type = mlir::RankedTensorType::get( - transposed_shape, tr_input_type.getElementType()); - - return builder.create(loc, transposed_type, input, - perm_op); -} - mlir::Operation* EmitCollectiveReduceScatter( mlir::OpBuilder& builder, const mlir::Location& loc, mlir::Value input, mlir::Type output_type, const std::string& reduce_op_str, diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 00b4303a37c360..d6cfb986c3043f 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -532,6 +532,7 @@ dtensor_test( "//tensorflow/python/ops:random_ops", "//tensorflow/python/ops:resource_variable_ops_gen", "//tensorflow/python/ops:special_math_ops", + "//tensorflow/python/ops:spectral_ops_gen", "//tensorflow/python/ops:stateless_random_ops", "//tensorflow/python/ops:stateless_random_ops_gen", "//tensorflow/python/ops:string_ops_gen", @@ -585,6 +586,7 @@ dtensor_test( "//tensorflow/python/ops:random_ops", "//tensorflow/python/ops:resource_variable_ops_gen", "//tensorflow/python/ops:special_math_ops", + "//tensorflow/python/ops:spectral_ops_gen", "//tensorflow/python/ops:stateless_random_ops", "//tensorflow/python/ops:stateless_random_ops_gen", "//tensorflow/python/ops:string_ops_gen", diff --git a/tensorflow/dtensor/python/tests/spmd_test.py b/tensorflow/dtensor/python/tests/spmd_test.py index e8267bea656e36..ece2253cc28b6e 100644 --- a/tensorflow/dtensor/python/tests/spmd_test.py +++ b/tensorflow/dtensor/python/tests/spmd_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import gen_stateless_random_ops from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops import math_ops @@ -2685,6 +2686,1115 @@ def testMaxPoolGradWithBatchShardedInputs(self, padding): dtensor_grad[0]) +class DTensorFFTSPMDTest(test_util.DTensorBaseTest): + + def setUp(self): + super().setUp() + + # Builds a 2x3x3 mesh. 
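+    # The device grid used below is [1, 1, 2] over dims ('b', 'x', 'y'), so
+    # only the 'y' dimension spans more than one device.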
+ self._mesh_dim_b = 'b' + self._mesh_dim_x = 'x' + self._mesh_dim_y = 'y' + self._dims = [self._mesh_dim_b, self._mesh_dim_x, self._mesh_dim_y] + + global_ids = test_util.create_device_ids_array([1, 1, 2]) + local_ids = np.ravel(global_ids).tolist() + + mesh_dict = { + device: Mesh(self._dims, global_ids, local_ids, + test_util.create_device_list([1, 1, 2], 'CPU')) + for device in ('CPU', 'GPU', 'TPU') + } + self.mesh = self.configTestMesh(mesh_dict) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + ) + def testFFT(self, input_layout_specs, expected_layout_specs, input_datatype): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.fft(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.fft(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + 
testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + ) + def testIFFT(self, input_layout_specs, expected_layout_specs, input_datatype): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.ifft(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.ifft(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + 
expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.complex128, + ), + ) + def testFFT2(self, input_layout_specs, expected_layout_specs, input_datatype): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.fft2(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.fft2d(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + ) + def testIFFT2( + self, input_layout_specs, expected_layout_specs, input_datatype + ): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.ifft2(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.ifft2d(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + 
layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'y'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'y'], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.complex128, + ), + ) + def testFFT3(self, input_layout_specs, expected_layout_specs, input_datatype): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.fftn(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.fft3d(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['x', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['x', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', 
layout_lib.UNSHARDED], + expected_layout_specs=['x', 'b', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['x', 'b', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + ), + ) + def testIFFT3( + self, input_layout_specs, expected_layout_specs, input_datatype + ): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.ifftn(x) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + result = gen_spectral_ops.ifft3d(x) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='unsharded_float32', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='unsharded_float64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='fullySharded_float32', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='fullySharded_float64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='partiallySharded1_float32', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='partiallySharded1_float64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='partiallySharded2_float32', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='partiallySharded2_float64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + ) + + # Each input dimension for RFFT must have length of at least fft_length[i]. 
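+  # The last (unsharded) output dimension then has size fft_length[-1] // 2 + 1.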
+ def testRFFT( + self, + input_layout_specs, + expected_layout_specs, + input_datatype, + complex_type, + ): + x = constant_op.constant( + np.random.normal(0, 10, 80).reshape([2, 4, 10]), + dtype=input_datatype, + ) + expected_result = np.fft.rfft(x, n=10) + if input_datatype == dtypes.float32: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [10], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.rfft(x, fft_length=length, Tcomplex=complex_type) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', layout_lib.UNSHARDED, 'x'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + ) + def testRFFT2( + self, + input_layout_specs, + expected_layout_specs, + input_datatype, + complex_type, + ): + x = constant_op.constant( + np.random.normal(0, 10, 96).reshape([2, 4, 12]), + dtype=input_datatype, + ) + expected_result = np.fft.rfft2(x, s=[4, 10]) + if input_datatype == dtypes.float32: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [4, 10], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.rfft2d( + x, fft_length=length, Tcomplex=complex_type + ) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + 
layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'y'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'y'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.float32, + complex_type=np.complex64, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=[layout_lib.UNSHARDED, 'b', 'x'], + input_datatype=dtypes.float64, + complex_type=np.complex128, + ), + ) + def testRFFT3( + self, + input_layout_specs, + expected_layout_specs, + input_datatype, + complex_type, + ): + x = constant_op.constant( + np.random.normal(0, 10, 80).reshape([2, 4, 10]), + dtype=input_datatype, + ) + expected_result = np.fft.rfftn(x, s=[2, 4, 10]) + if input_datatype == dtypes.float32: + expected_result = expected_result.astype(np.complex64) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [2, 4, 10], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.rfft3d( + x, fft_length=length, Tcomplex=complex_type + ) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_fullySharded_compelx128', 
+ input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + ) + + # Each input dimension for IRFFT must have length of at least fft_length[i]. + # For the inner-most input dimension, it should >= fft_shape[i] / 2 + 1. + def testIRFFT( + self, input_layout_specs, expected_layout_specs, input_datatype, real_type + ): + a = np.random.normal(0, 10, 80).reshape([2, 4, 10]) + b = np.random.normal(0, 10, 80).reshape([2, 4, 10]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.irfft(x, n=8) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.float32) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [8], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.irfft(x, fft_length=length, Treal=real_type) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + 
real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['b', 'x', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + ) + def testIRFFT2( + self, input_layout_specs, expected_layout_specs, input_datatype, real_type + ): + a = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + b = np.random.normal(0, 10, 64).reshape([2, 4, 8]) + x = constant_op.constant( + a + b * 1j, + dtype=input_datatype, + ) + expected_result = np.fft.irfft2(x, s=[4, 8]) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.float32) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [4, 8], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.irfft2d(x, fft_length=length, Treal=real_type) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + @parameterized.named_parameters( + dict( + testcase_name='_unsharded_complex64', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_unsharded_complex128', + input_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + expected_layout_specs=[ + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + layout_lib.UNSHARDED, + ], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_fullySharded_complex64', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['x', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_fullySharded_compelx128', + input_layout_specs=['b', 'x', 'y'], + expected_layout_specs=['x', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded1_complex64', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded1_complex128', + input_layout_specs=['b', layout_lib.UNSHARDED, 'y'], + expected_layout_specs=['b', 'y', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + dict( + testcase_name='_partiallySharded2_complex64', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['x', 'b', layout_lib.UNSHARDED], + input_datatype=dtypes.complex64, + real_type=dtypes.float32, + ), + dict( + testcase_name='_partiallySharded2_complex128', + input_layout_specs=['b', 'x', layout_lib.UNSHARDED], + expected_layout_specs=['x', 'b', layout_lib.UNSHARDED], + input_datatype=dtypes.complex128, + real_type=dtypes.float64, + ), + ) + def testIRFFT3( + self, input_layout_specs, expected_layout_specs, input_datatype, real_type + ): + a = np.random.normal(0, 10, 48).reshape([2, 4, 6]) + b = np.random.normal(0, 10, 48).reshape([2, 4, 6]) + x = constant_op.constant( + a + b * 1j, + 
dtype=input_datatype, + ) + expected_result = np.fft.irfftn(x, s=[2, 4, 8]) + if input_datatype == dtypes.complex64: + expected_result = expected_result.astype(np.float32) + x = api.copy_to_mesh(x, Layout(input_layout_specs, self.mesh)) + length = constant_op.constant( + [2, 4, 8], + dtype=dtypes.int32, + ) + result = gen_spectral_ops.irfft3d(x, fft_length=length, Treal=real_type) + self.assertDTensorEqual( + expected_result, Layout(expected_layout_specs, self.mesh), result + ) + + class DTensorLayoutPropSPMDTest(test_util.DTensorBaseTest): def setUp(self): From 5c33f00e273116ceb3ae2d05fc65235053c4780a Mon Sep 17 00:00:00 2001 From: Zhufeng Pan Date: Mon, 24 Jul 2023 13:46:01 -0700 Subject: [PATCH 064/410] Disable a flaky test PiperOrigin-RevId: 550661322 --- .../kernel_tests/nn_ops/conv2d_backprop_filter_grad_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/kernel_tests/nn_ops/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/nn_ops/conv2d_backprop_filter_grad_test.py index e5d96064a30e37..d8cb6703e25bc6 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv2d_backprop_filter_grad_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv2d_backprop_filter_grad_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for convolution related functionality in tensorflow.ops.nn.""" +import unittest import numpy as np from tensorflow.python.framework import constant_op @@ -30,6 +31,8 @@ "Run Conv2D backprop without TF32 on GPU") class Conv2DBackpropFilterGradTest(test.TestCase): + # TODO(b/292002914): Enable this test after fixing its flakyness. + @unittest.skip("Disable the flaky test.") @test_util.run_deprecated_v1 def testGradient(self): with self.cached_session(): From 072d8352af902421201966b799b8e94a6d4d8faa Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Mon, 24 Jul 2023 13:54:09 -0700 Subject: [PATCH 065/410] #tf-data Set up `file_locality_v2` experiment. PiperOrigin-RevId: 550663606 --- tensorflow/core/data/dataset_utils.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index f7030487f5465e..5a4f58b6f06256 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -978,6 +978,8 @@ REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<50>, AllTasks); REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<10>, IndependentHostTasks); +REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<0>, + IndependentHostTasks); } // namespace } // namespace data } // namespace tensorflow From d54fc3591d342737ec200d7993fe11efe4d223de Mon Sep 17 00:00:00 2001 From: Michael Delorimier Date: Mon, 24 Jul 2023 13:55:14 -0700 Subject: [PATCH 066/410] Parameterize TPUDeviceOrdinalPlaceholder with the TPU_REPLICATED_CORE_i. i is the logical core, which is in the range [0, num_cores_per_replica). 
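For example, with devices {TPU_REPLICATED_CORE_0 = [TPU:1, TPU:2], TPU_REPLICATED_CORE_1 = [TPU:3, TPU:4]}, a placeholder with logical_core = 1 in replica 1 lowers to the constant device ordinal 4 (see the updated replicate_to_island tests).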
PiperOrigin-RevId: 550663934 --- .../compiler/mlir/tensorflow/ir/tf_ops.td | 18 +++--- .../tests/extract_outside_compilation.mlir | 20 +++---- .../replicate_invariant_op_hoisting.mlir | 2 +- .../tensorflow/tests/replicate_to_island.mlir | 47 ++++++++++----- .../tests/replicate_to_island_legacy.mlir | 4 +- .../tests/side-effect-analysis-test.mlir | 2 +- .../transforms/extract_outside_compilation.cc | 23 ++------ .../transforms/replicate_to_island.cc | 59 ++++++++++--------- 8 files changed, 90 insertions(+), 85 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index d40089d29482e9..40b2924c0b275d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -1423,18 +1423,22 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing def TF__TPUDeviceOrdinalPlaceholderOp : TF_Op<"_TPUDeviceOrdinalPlaceholder", [Pure]> { let summary = [{ -Placeholder device ordinal that represents device ordinal of a replicated op. +Placeholder for a device ordinal that depends on its tf_device.replicate ancestor. }]; let description = [{ -This op can be used when certain rewrite passes materialize ops that require a -device ordinal of a replicated op but replication logic has been abstracted away -using tf_device.replicate op. Subsequent rewrite passes must replace this op with -a constant output that represents the correct device ordinal of the replicated -operations inside a TPU host. +This op must have a tf_device.replicate ancestor. The ancestor replica_id and +logical_core attribute correspond to a TPU core. This op maps the TPU core to a +device_ordinal, where the device ordinal is the index of the core relative to +its host. + +The replicate_to_island pass removes and flattens tf_device.replicate, so it +converts this op to the constant index of the core relative to its host. 
}]; - let arguments = (ins); + let arguments = (ins + I64Attr:$logical_core + ); let results = (outs TF_Int64Tensor:$device_ordinal diff --git a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir index 97d1631a916a7a..488c98af16f55f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/extract_outside_compilation.mlir @@ -2175,28 +2175,24 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0", // CHECK: tf_device.replicate // CHECK: "tf_device.parallel_execute" // CHECK: %[[PROGRAM0:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey" - // CHECK: %[[DEVICE0_0:.+]] = "tf._TPUDeviceOrdinalPlaceholder" - // CHECK: %[[RECV0:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM0]], %[[DEVICE0_0]]) + // CHECK: %[[DEVICE0:.+]] = "tf._TPUDeviceOrdinalPlaceholder" + // CHECK-SAME: logical_core = 0 + // CHECK: %[[RECV0:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM0]], %[[DEVICE0]]) // CHECK-SAME: _xla_has_host_transfer = true // CHECK-SAME: key = "host_compute_channel_0_args" // CHECK: %[[B0:.+]] = "tf.OpB"(%[[RECV0]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> - // CHECK: "tf._XlaSendFromHostV2"(%[[B0]], %[[PROGRAM0]], %[[DEVICE0_0]]) + // CHECK: "tf._XlaSendFromHostV2"(%[[B0]], %[[PROGRAM0]], %[[DEVICE0]]) // CHECK-SAME: _xla_has_host_transfer = true // CHECK-SAME: key = "host_compute_channel_0_retvals" // CHECK: }, { // CHECK: %[[PROGRAM1:.+]] = "tf._XlaCompileMlirPlaceholderProgramKey" - // CHECK: %[[DEVICE1_0:.+]] = "tf._TPUDeviceOrdinalPlaceholder" - // CHECK: %[[ONE_0:.+]] = "tf.Const" - // CHECK-SAME: value = dense<1> - // CHECK: %[[DEVICE1_1:.+]] = "tf.AddV2"(%[[DEVICE1_0]], %[[ONE_0]]) - // CHECK: %[[RECV1:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM1]], %[[DEVICE1_1]]) + // CHECK: %[[DEVICE1:.+]] = "tf._TPUDeviceOrdinalPlaceholder" + // CHECK-SAME: logical_core = 1 + // CHECK: %[[RECV1:.+]] = "tf._XlaRecvAtHostV2"(%[[PROGRAM1]], %[[DEVICE1]]) // CHECK-SAME: _xla_has_host_transfer = true // CHECK-SAME: key = "host_compute_channel_0_args" // CHECK: %[[B1:.+]] = "tf.OpB"(%[[RECV1]]) : (tensor<2x2xi64>) -> tensor<2x2xi64> - // CHECK: %[[ONE_1:.+]] = "tf.Const" - // CHECK-SAME: value = dense<1> - // CHECK: %[[DEVICE1_2:.+]] = "tf.AddV2"(%[[DEVICE1_0]], %[[ONE_1]]) - // CHECK: "tf._XlaSendFromHostV2"(%[[B1]], %[[PROGRAM1]], %[[DEVICE1_2]]) + // CHECK: "tf._XlaSendFromHostV2"(%[[B1]], %[[PROGRAM1]], %[[DEVICE1]]) // CHECK-SAME: _xla_has_host_transfer = true // CHECK-SAME: key = "host_compute_channel_0_retvals" // CHECK: }, { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir index b34a26431c0a05..024caf9297bd3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_invariant_op_hoisting.mlir @@ -37,7 +37,7 @@ func.func @not_invariant_ordinal_placeholder(%arg0: tensor<*xf32>, %arg1: tensor // CHECK: tf_device.replicate // CHECK: tf._TPUDeviceOrdinalPlaceholder %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<*xf32>) {n = 2: i32} { - %1 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + %1 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor tf_device.return %1 : tensor } func.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 595f20895fe8b4..8e0e4558b851f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -251,30 +251,27 @@ func.func @replica_id_attr_added(%arg0: tensor, %arg1: tensor tensor - tf_device.return %2 : tensor + %0:1 = tf_executor.island { + tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"], TPU_REPLICATED_CORE_1 = ["/job:worker/replica:0/task:0/device:TPU:3", "/job:worker/replica:0/task:0/device:TPU:4"]}} { + %1 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor + %2 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 1} : () -> tensor + tf_device.return } - tf_executor.yield %1#0, %1#1 : tensor, tensor + tf_executor.yield } tf_executor.fetch } func.return } -// CHECK: tf_executor.island -// CHECK: [[CONST_0:%.+]] = "tf.Const" -// CHECK-SAME: _parallel_execution_ids = "r0:0", value = dense<1> : tensor -// CHECK: tf_executor.yield [[CONST_0]] -// CHECK: tf_executor.island -// CHECK: [[CONST_1:%.+]] = "tf.Const" -// CHECK-SAME: _parallel_execution_ids = "r0:1", value = dense<2> : tensor -// CHECK: tf_executor.yield [[CONST_1]] +// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:0", value = dense<1> : tensor} +// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:0", value = dense<3> : tensor} +// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:1", value = dense<2> : tensor} +// CHECK: tf_executor.island wraps "tf.Const"() {_parallel_execution_ids = "r0:1", value = dense<4> : tensor} // ----- // Tests parallel_execute nested inside replicate @@ -360,7 +357,27 @@ func.func @missing_device_ordinals() { %0:3 = tf_executor.island { %1:2 = tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_1 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { // expected-error@below {{requires device ordinal from device TPU_REPLICATED_CORE_0 to be present in 'tf.device.replicate' op}} - %2 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + %2 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor + tf_device.return %2 : tensor + } + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// ----- + +// Tests tf._TPUDeviceOrdinalPlaceholder cannot be updated when device ordinal +// is missing. 
+ +func.func @missing_devices() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = tf_device.replicate {n = 2 : i32} { + // expected-error@below {{devices attribute is not present}} + %2 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor tf_device.return %2 : tensor } tf_executor.yield %1#0, %1#1 : tensor, tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir index 91ac4a2e76d9d2..24d498ebe88601 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir @@ -237,7 +237,7 @@ func.func @device_ordinals() { tf_executor.graph { %0:3 = tf_executor.island { %1:2 = tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { - %2 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + %2 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor tf_device.return %2 : tensor } tf_executor.yield %1#0, %1#1 : tensor, tensor @@ -266,7 +266,7 @@ func.func @missing_device_ordinals() { %0:3 = tf_executor.island { %1:2 = tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_1 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { // expected-error@below {{requires device ordinal from device TPU_REPLICATED_CORE_0 to be present in 'tf.device.replicate' op}} - %2 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + %2 = "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor tf_device.return %2 : tensor } tf_executor.yield %1#0, %1#1 : tensor, tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index 7ff56b0a6be9ec..575b510cff8e31 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -1939,7 +1939,7 @@ func.func @device_ordinal_placeholder_side_effect_free( %island = tf_executor.island { // expected-remark@above {{ID: 3}} // expected-remark@above {{Successors: {4}}} - "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + "tf._TPUDeviceOrdinalPlaceholder"() {logical_core = 0} : () -> tensor // expected-remark@above {{ID: 0}} "tf._UnknownSideEffectingOp_"() : () -> () // expected-remark@above {{ID: 1}} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc index 16d27c0f3167c8..2906a9d059fa6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_outside_compilation.cc @@ -798,28 +798,12 @@ void CloneFirstHost(llvm::SmallVector& core_to_mapping, llvm::dyn_cast(clone)) { recv_at_host.setOperand(0, core_to_compilation_key[core]); builder.setInsertionPoint(recv_at_host); - // core_ordinal = device_ordinal + core - // where device_ordinal is the base device for the replica - Value device_ordinal = core_to_device_ordinal[core]; - Value const_core = builder.create( - recv_at_host.getLoc(), builder.getI64IntegerAttr(core)); - Value core_ordinal = builder.create( - recv_at_host.getLoc(), device_ordinal.getType(), device_ordinal, - const_core); - 
recv_at_host.setOperand(1, core_ordinal); + recv_at_host.setOperand(1, core_to_device_ordinal[core]); } else if (auto send_from_host = llvm::dyn_cast(clone)) { send_from_host.setOperand(1, core_to_compilation_key[core]); builder.setInsertionPoint(send_from_host); - // core_ordinal = device_ordinal + core - // where device_ordinal is the base device for the replica - Value device_ordinal = core_to_device_ordinal[core]; - Value const_core = builder.create( - send_from_host.getLoc(), builder.getI64IntegerAttr(core)); - Value core_ordinal = builder.create( - send_from_host.getLoc(), device_ordinal.getType(), device_ordinal, - const_core); - send_from_host.setOperand(2, core_ordinal); + send_from_host.setOperand(2, core_to_device_ordinal[core]); } } } @@ -1441,7 +1425,8 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( if (has_tpu_device) { device_ordinal_op = builder.create( device_cluster.getLoc(), - RankedTensorType::get({}, builder.getI64Type())); + RankedTensorType::get({}, builder.getI64Type()), + builder.getI64IntegerAttr(core)); } else { device_ordinal_op = builder.create( device_cluster.getLoc(), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 70ecb08ef785e7..51a7250202f834 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -51,7 +51,6 @@ namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kReplicaIdAttr[] = "_xla_replica_id"; constexpr char kDeviceOrdinalAttr[] = "device_ordinal"; -constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0"; #define GEN_PASS_DEF_REPLICATETOISLANDPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" @@ -71,25 +70,31 @@ bool RequiresReplicaIDAttribute(Operation* op) { TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp>(op); } -// Collects TPU device ordinal for outside compilation communication ops. This -// currently assumes outside compilation only uses `TPU_REPLICATED_CORE_0` -// aliased device for the device computation. -std::optional GetDeviceOrdinal( - const std::optional& devices, Location loc, - unsigned replica_id) { - int64_t device_ordinal = 0; - if (devices.has_value()) { - if (auto tpu_replica_0 = devices.value().get(kTPUCore0)) { - llvm::StringRef tpu_device = tpu_replica_0.cast()[replica_id] - .cast() - .getValue(); - if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString( - loc, tpu_device, &device_ordinal))) { - return std::optional(device_ordinal); - } - } +// Returns the device ordinal (`device_ordinal`) for a replica (`replica_id`) +// and logical core (`logical_core`). +// `replica_id` is the index of the ancestor ReplicateOp in [0, num_replicas). +// `logical_core` is the index of the TPU core in [0, num_cores_per_replica). +// `device_ordinal` is the index of the TPU core relative to its host. 
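+// For example, with devices {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}, replica_id = 1 and logical_core = 0 map to device_ordinal = 2 (illustrative; matches the replicate_to_island tests).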
+LogicalResult GetDeviceOrdinal(const std::optional& devices, + const unsigned replica_id, + const uint64_t logical_core, + int64_t& device_ordinal, Operation* op) { + if (!devices.has_value()) { + return op->emitOpError() + << "devices attribute is not present in 'tf.device.replicate' op"; } - return std::nullopt; + auto logical_core_name = + tensorflow::GetDeviceAliasForLogicalCore(logical_core); + auto tpu_replica = devices.value().get(logical_core_name); + if (!tpu_replica) { + return op->emitOpError() + << "requires device ordinal from device " << logical_core_name + << " to be present in 'tf.device.replicate' op"; + } + llvm::StringRef tpu_device = + tpu_replica.cast()[replica_id].cast().getValue(); + return tensorflow::GetDeviceOrdinalFromDeviceString(op->getLoc(), tpu_device, + &device_ordinal); } // Updates replica variant ops in a region based on replica `replica_id`. @@ -101,26 +106,24 @@ std::optional GetDeviceOrdinal( LogicalResult UpdateRegionReplicateVariantOps( OpBuilder& builder, Location loc, Region& region, int replica_id, const std::optional& devices) { - std::optional device_ordinal = - GetDeviceOrdinal(devices, loc, replica_id); - auto result = region.walk([&](Operation* op) -> WalkResult { if (RequiresReplicaIDAttribute(op)) { op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id)); return WalkResult::advance(); } - if (isa(op)) { - if (!device_ordinal.has_value()) - return op->emitOpError() - << "requires device ordinal from device " << kTPUCore0 - << " to be present in 'tf.device.replicate' op"; + if (auto placeholder = dyn_cast(op)) { + int64_t device_ordinal; + if (failed(GetDeviceOrdinal(devices, replica_id, + placeholder.getLogicalCore(), device_ordinal, + op))) + return failure(); OpBuilder builder(op); auto const_op = builder.create( op->getLoc(), DenseIntElementsAttr::get( RankedTensorType::get({}, builder.getI64Type()), - {device_ordinal.value()})); + {device_ordinal})); op->replaceAllUsesWith(const_op); op->erase(); return WalkResult::advance(); From c58073a3044d9bf23bc87d49c82d7185ce1fb513 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 24 Jul 2023 13:55:24 -0700 Subject: [PATCH 067/410] Fix bazel build after bumping LLVM revision - Remove uses of Attributes.td PiperOrigin-RevId: 550663977 --- third_party/stablehlo/temporary.patch | 84 +++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 12 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 1b32eb1c8480f3..b697e5c80707ca 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,10 +1,35 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -293,6 +293,24 @@ - ) - - cc_library( +@@ -227,20 +227,6 @@ + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/integrations/python/mlir/dialects/ChloOps.td", + deps = [ +- ":chlo_ops_py_td_files", +- ], +-) +- +-td_library( +- name = "chlo_ops_py_td_files", +- srcs = [ +- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", +- ], +- includes = [ +- ".", +- "include", +- ], +- deps = [ + ":chlo_ops_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +@@ -289,6 +275,24 @@ + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:TransformUtils", ++ ], ++) ++ ++cc_library( + name = "experimental_ops", + srcs = [ + "stablehlo/dialect/ExperimentalOps.cpp", @@ -19,14 +44,28 @@ diff --ruN 
a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", -+ ], -+) -+ -+cc_library( - name = "reference_axes", - srcs = [ - "stablehlo/reference/Axes.cpp", -@@ -702,6 +720,7 @@ + ], + ) + +@@ -635,17 +639,6 @@ + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/integrations/python/mlir/dialects/StablehloOps.td", + deps = [ +- ":stablehlo_ops_py_td_files", +- ], +-) +- +-td_library( +- name = "stablehlo_ops_py_td_files", +- srcs = [ +- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", +- ], +- includes = ["."], +- deps = [ + ":stablehlo_ops_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], +@@ -702,6 +695,7 @@ deps = [ ":base", ":chlo_ops", @@ -34,6 +73,27 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ":stablehlo_ops", ":stablehlo_ops_inc_gen", ":stablehlo_pass_inc_gen", +@@ -1014,20 +1008,6 @@ + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/integrations/python/mlir/dialects/VhloOps.td", + deps = [ +- ":vhlo_ops_py_td_files", +- ], +-) +- +-td_library( +- name = "vhlo_ops_py_td_files", +- srcs = [ +- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", +- ], +- includes = [ +- ".", +- "include", +- ], +- deps = [ + ":vhlo_ops_td_files", + "@llvm-project//mlir:OpBaseTdFiles", + ], diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt From 18c517325bdd9dc6f4d2c4b6e375820173e2ed14 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 14:01:58 -0700 Subject: [PATCH 068/410] Set default attribute values in `Convert1DConvOp` if not set PiperOrigin-RevId: 550665807 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 29 +++++++++++++++++++ .../tensorflow/transforms/legalize_hlo.cc | 18 +++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 979445dc3a8cce..7c5a8261d5fd12 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1719,6 +1719,35 @@ func.func @convert_conv1d(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256 func.return %0 : tensor<16x32x256xbf16> } +// CHECK-LABEL: func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x256xbf16>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[16, 32, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = "tf.Reshape"(%[[VAL_0]], %[[VAL_2]]) : (tensor<16x32x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK-DAG: %[[VAL_4:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_5:.*]] = "tf.Transpose"(%[[VAL_3]], %[[VAL_4]]) : (tensor<16x32x256x1xbf16>, tensor<4xi64>) -> tensor<16x32x1x256xbf16> +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[1, 256, 256, 1]> : tensor<4xi64> +// CHECK: %[[VAL_7:.*]] = "tf.Reshape"(%[[VAL_1]], %[[VAL_6]]) : (tensor<1x256x256xbf16>, tensor<4xi64>) -> tensor<1x256x256x1xbf16> +// CHECK-DAG: %[[VAL_8:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_7]], %[[VAL_8]]) : (tensor<1x256x256x1xbf16>, tensor<4xi64>) -> tensor<1x1x256x256xbf16> +// 
CHECK: %[[VAL_10:.*]] = "tf.Conv2D"(%[[VAL_5]], %[[VAL_9]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<16x32x1x256xbf16>, tensor<1x1x256x256xbf16>) -> tensor<16x32x1x256xbf16> +// CHECK: %[[VAL_11:.*]] = "tf.Const"() {value = dense<[0, 1, 3, 2]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[VAL_12:.*]] = "tf.Transpose"(%[[VAL_10]], %[[VAL_11]]) : (tensor<16x32x1x256xbf16>, tensor<4xi64>) -> tensor<16x32x256x1xbf16> +// CHECK: %[[VAL_13:.*]] = arith.constant dense<[16, 32, 256]> : tensor<3xi64> +// CHECK: %[[VAL_14:.*]] = "tf.Reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<16x32x256x1xbf16>, tensor<3xi64>) -> tensor<16x32x256xbf16> +// CHECK: return %[[VAL_14]] : tensor<16x32x256xbf16> +// CHECK: } +func.func @convert_conv1d_no_lhs_dil_rhs_dil_precision_conf(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + padding = dense<0> : tensor<1x2xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xbf16>, tensor<1x256x256xbf16>) -> tensor<16x32x256xbf16> + func.return %0 : tensor<16x32x256xbf16> +} + // CHECK-LABEL: func.func @convert_conv1d_non_canonical_dimension_numbers( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x256xbf16>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<256x1x256xbf16>) -> tensor<256x16x32xbf16> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 9a15a022947184..9cee8917a8b60f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -311,6 +311,12 @@ class Convert1DConvOp : public OpConversionPattern, RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d_array); // LHS dilation + // Set LHS dilation defaults if not set (1 for each input spatial dimension) + if (!conv_op.getLhsDilation().has_value()) { + conv_op.setLhsDilationAttr(rewriter.getI64TensorAttr( + std::vector(dnums.getInputSpatialDimensions().size(), 1))); + } + SmallVector lhs_dilation_array_2d; for (const auto v : conv_op.getLhsDilation().value().getValues()) { lhs_dilation_array_2d.emplace_back(v); @@ -321,6 +327,13 @@ class Convert1DConvOp : public OpConversionPattern, lhs_dilation_array_2d); // RHS dilation + // Set RHS dilation defaults if not set (1 for each kernel spatial + // dimension) + if (!conv_op.getRhsDilation().has_value()) { + conv_op.setRhsDilationAttr(rewriter.getI64TensorAttr( + std::vector(dnums.getKernelSpatialDimensions().size(), 1))); + } + SmallVector rhs_dilation_array_2d; for (const auto v : conv_op.getRhsDilation().value().getValues()) { rhs_dilation_array_2d.emplace_back(v); @@ -338,9 +351,6 @@ class Convert1DConvOp : public OpConversionPattern, RankedTensorType::get({2}, rewriter.getI64Type()), SmallVector({0, 0})); - // Precision config - if (!conv_op.getPrecisionConfig().has_value()) return failure(); - // Dimension numbers reflect the form of the 2d conv op NWHC * WHIO -> NWHC auto dnums_2d = mhlo::ConvDimensionNumbersAttr::get(rewriter.getContext(), @@ -380,7 +390,7 @@ class Convert1DConvOp : public OpConversionPattern, transposed_image_2d_op.getResult(), transposed_kernel_2d_op.getResult(), window_strides_2d, padding_2d, lhs_dilation_2d, rhs_dilation_2d, 
window_reversal_2d, dnums_2d, conv_op.getFeatureGroupCount(), - conv_op.getBatchGroupCount(), *conv_op.getPrecisionConfig()); + conv_op.getBatchGroupCount(), conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); auto conv2d_output_type = conv2d_output.getType().cast(); From 97aa8b7c245d115814893eb22474a9f19e0145ed Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Mon, 24 Jul 2023 14:26:57 -0700 Subject: [PATCH 069/410] Profiler sessions could consist of multiple xplanes. PiperOrigin-RevId: 550673445 --- .../core/profiler/utils/hlo_proto_map.cc | 54 +++++++++++-------- .../core/profiler/utils/hlo_proto_map.h | 2 +- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc index a0f90aaecb0d78..edfd4af2bb483b 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ b/tensorflow/core/profiler/utils/hlo_proto_map.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include +#include #include #include #include @@ -47,30 +48,37 @@ int NumHeapSimulatorTraceEvents(const xla::HloProto* hlo) { } // namespace -std::vector>> +absl::flat_hash_map> ParseHloProtosFromXSpace(const XSpace& space) { - std::vector>> hlo_protos; - const XPlane* raw_plane = FindPlaneWithName(space, kMetadataPlaneName); - if (raw_plane != nullptr) { - XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); - - const XStatMetadata* hlo_proto_stat_metadata = - plane.GetStatMetadataByType(StatType::kHloProto); - if (hlo_proto_stat_metadata != nullptr) { - plane.ForEachEventMetadata( - [&](const XEventMetadataVisitor& event_metadata) { - auto hlo_proto_stat = event_metadata.GetStat( - StatType::kHloProto, *hlo_proto_stat_metadata); - if (!hlo_proto_stat) return; - if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; - auto hlo_proto = std::make_unique(); - absl::string_view byte_value = hlo_proto_stat->BytesValue(); - if (hlo_proto->ParseFromArray(byte_value.data(), - byte_value.size())) { - hlo_protos.emplace_back(event_metadata.Id(), - std::move(hlo_proto)); - } - }); + absl::flat_hash_map> hlo_protos; + std::vector planes = + FindPlanesWithNames(space, {kMetadataPlaneName}); + for (const XPlane* raw_plane : planes) { + if (raw_plane != nullptr) { + XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(raw_plane); + + const XStatMetadata* hlo_proto_stat_metadata = + plane.GetStatMetadataByType(StatType::kHloProto); + if (hlo_proto_stat_metadata != nullptr) { + plane.ForEachEventMetadata( + [&](const XEventMetadataVisitor& event_metadata) { + auto hlo_proto_stat = event_metadata.GetStat( + StatType::kHloProto, *hlo_proto_stat_metadata); + if (!hlo_proto_stat) return; + if (hlo_proto_stat->ValueCase() != XStat::kBytesValue) return; + auto hlo_proto = std::make_unique(); + absl::string_view byte_value = hlo_proto_stat->BytesValue(); + if (hlo_proto->ParseFromArray(byte_value.data(), + byte_value.size())) { + if (!hlo_protos + .try_emplace(event_metadata.Id(), std::move(hlo_proto)) + .second) { + LOG(WARNING) << "Insert failed for hlo_proto with program_id" + << event_metadata.Id(); + } + } + }); + } } } return hlo_protos; diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.h b/tensorflow/core/profiler/utils/hlo_proto_map.h index c26b24d63c38f1..fc6bba48d90373 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.h +++ b/tensorflow/core/profiler/utils/hlo_proto_map.h @@ -31,7 +31,7 @@ 
limitations under the License. namespace tensorflow { namespace profiler { -std::vector>> +absl::flat_hash_map> ParseHloProtosFromXSpace(const XSpace& space); class HloProtoMap { From d6ee973f2f498a07c203f4e2ca021eee64a3e186 Mon Sep 17 00:00:00 2001 From: Russell Power Date: Mon, 24 Jul 2023 14:28:07 -0700 Subject: [PATCH 070/410] Add support for constructing partial XLA shapes from TF. PiperOrigin-RevId: 550673775 --- .../compiler/jit/mark_for_compilation_pass.cc | 2 + tensorflow/compiler/tf2xla/shape_util.cc | 43 +++++++++++++++++++ tensorflow/compiler/tf2xla/shape_util.h | 5 +++ 3 files changed, 50 insertions(+) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index a6869ff741e3af..23391ac6fd9824 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2063,6 +2063,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "DepthwiseConv2dNativeBackpropInput", "Dequantize", "Diag", + "DynamicInfeedEnqueueTupleOp", + "DynamicInfeedDequeueTupleOp", "DynamicStitch", "DynamicPartition", "Einsum", diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index e40df038bbb63a..8a266c2f1eed96 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { namespace { @@ -109,6 +111,47 @@ Status TensorShapeToXLAShape(DataType dtype, return OkStatus(); } +Status TensorShapeToBoundedXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + const TensorShape& bound, + xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + if (tensor_shape.unknown_rank()) { + // For unknown shape, create a rank 1 size 0 tensor. + *shape = xla::ShapeUtil::MakeShapeWithDenseLayout(type, {0}, {0}); + return OkStatus(); + } + + if (tensor_shape.dims() != bound.dims()) { + return errors::InvalidArgument( + "`tensor_shape` and `bound` have different ranks. tensor_shape=", + tensor_shape.dims(), "vs bound=", bound.dims()); + } + + int rank = tensor_shape.dims(); + std::vector dimensions(rank); + std::vector layout(rank); + for (int d = 0; d < rank; ++d) { + if (bound.dim_size(d) < 0) { + return errors::InvalidArgument("Bound dimension ", d, + " has unknown size."); + } + dimensions[d] = bound.dim_size(d); + } + // XLA uses minor-to-major; Tensorflow uses major-to-minor. 
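+  // For example, a rank-3 tensor gets layout {2, 1, 0} here.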
+ std::iota(layout.rbegin(), layout.rend(), 0); + xla::Shape result = + xla::ShapeUtil::MakeShapeWithDenseLayout(type, dimensions, layout); + for (int d = 0; d < rank; ++d) { + if (tensor_shape.dim_size(d) < 0) { + result.set_dynamic_dimension(d, true); + } + } + *shape = result; + return OkStatus(); +} + xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const PartialTensorShape& tensor_shape) { if (tensor_shape.unknown_rank()) { diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 67ae3898a82342..57bcabcb2eb950 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -58,6 +58,11 @@ Status TensorShapeToXLAShape(DataType dtype, xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const PartialTensorShape& tensor_shape); +Status TensorShapeToBoundedXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + const TensorShape& bound, + xla::Shape* shape); + // Given an XLA shape with layouts, builds a layout vector in the form able to // be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... // THe returned vector is a linearized sequence of the minor-to-major values of From 87601ae074f69d30800dc205db8840e7d1b1f6fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 14:55:03 -0700 Subject: [PATCH 071/410] Allow op costs to be periodically updated. PiperOrigin-RevId: 550681655 --- tensorflow/core/tfrt/graph_executor/BUILD | 2 + .../graph_executor/graph_execution_options.h | 17 +- .../tfrt/graph_executor/graph_executor.cc | 61 +++++-- .../core/tfrt/graph_executor/graph_executor.h | 26 ++- .../graph_executor/graph_executor_test.cc | 167 +++++++++++++++++- 5 files changed, 251 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 14c2292a020c7d..fcbbce16ae2e81 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -33,6 +33,7 @@ cc_library( "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/tfrt/runtime:work_queue_interface", "//tensorflow/core/tfrt/utils:bridge_graph_analysis", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) @@ -147,6 +148,7 @@ tf_cc_test( "//tensorflow/core/tfrt/saved_model:saved_model_testutil", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", "@tf_runtime//:tensor", "@tf_runtime//cpp_tests:common", diff --git a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h index 3a42bfe2447850..8dec13610505a6 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_execution_options.h +++ b/tensorflow/core/tfrt/graph_executor/graph_execution_options.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/core/framework/tensor.h" @@ -71,10 +72,20 @@ struct GraphExecutionOptions { // is overwritten when `enable_online_cost_analysis = true`. struct CostAnalysisOptions { enum CostAnalysisVersion { - DISABLED, - ONCE, // Cost recording and recompilation occurs on the first run only. + kDisabled, + kOnce, // Cost recording and recompilation occurs on the first run only. + kPeriodic, // This is experimental. 
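+    // kPeriodic keeps recording costs after the first execution and resets
+    // Op cost estimates (recompiling the executable) roughly once per
+    // reset_interval; see the fields below.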
}; - CostAnalysisVersion version = DISABLED; + CostAnalysisVersion version = kDisabled; + + // Time between resets in Op cost estimates. Upon reset, the executable + // will be recompiled. + // However, a reset always occurs after the first execution. + absl::Duration reset_interval = absl::ZeroDuration(); + + // Number of times to record costs before resetting Op cost estimates. + // However, a reset always occurs after the first execution. + int updates_per_interval = 1; }; CostAnalysisOptions cost_analysis_options; diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 0232cc5c2d4d27..41758ad036a9f7 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -445,7 +445,7 @@ StatusOr> GraphExecutor::Create( } if (options.enable_online_cost_analysis) { // Overrides cost_analysis_options. - options.cost_analysis_options.version = Options::CostAnalysisOptions::ONCE; + options.cost_analysis_options.version = Options::CostAnalysisOptions::kOnce; } TfrtGraphExecutionState::Options graph_execution_state_options; @@ -560,7 +560,10 @@ tensorflow::Status GraphExecutor::Run( // Possibly record costs, depending on the particular setting of // `CostAnalysisOptions`. - CostRecorder* cost_recorder = loaded_client_graph.MaybeGetCostRecorder(); + auto now = absl::Now() + simulated_duration_; + bool do_recompilation; + CostRecorder* cost_recorder = + loaded_client_graph.MaybeGetCostRecorder(now, &do_recompilation); std::vector flat_outputs; TF_RETURN_IF_ERROR(GraphExecutionRunOnFunction( @@ -573,13 +576,15 @@ tensorflow::Status GraphExecutor::Run( &req_deadline_tracker_, cost_recorder, loaded_client_graph.stream_callback_id())); - if (cost_recorder != nullptr) { + if (do_recompilation) { TF_RETURN_IF_ERROR( loaded_client_graph.UpdateCost(*cost_recorder, runtime())); tensorflow::mutex_lock l(num_recompilations_mu_); num_recompilations_ += 1; } - + if (cost_recorder != nullptr) { + loaded_client_graph.UpdateCostAnalysisData(now, do_recompilation); + } // Create the outputs from the actual function results, which are sorted // according to the output tensor names. auto flat_output_iter = flat_outputs.begin(); @@ -906,13 +911,27 @@ tensorflow::Status GraphExecutor::RunWithSyncInterpreter( return execution_context.status(); } -CostRecorder* GraphExecutor::LoadedClientGraph::MaybeGetCostRecorder() { +CostRecorder* GraphExecutor::LoadedClientGraph::MaybeGetCostRecorder( + absl::Time now, bool* do_recompilation) { + *do_recompilation = false; tensorflow::mutex_lock l(cost_analysis_data_.mu); if (!cost_analysis_data_.is_available) { return nullptr; } - cost_analysis_data_.is_available = false; - return cost_analysis_data_.cost_recorder.get(); + const auto& options = graph_executor_->options().cost_analysis_options; + absl::Duration elapsed_duration = now - cost_analysis_data_.start_time; + double intended_num_updates = absl::ToDoubleSeconds(elapsed_duration) / + absl::ToDoubleSeconds(options.reset_interval) * + options.updates_per_interval; + // Compare with the actual number of cost updates to decide whether or not to + // record costs for this particular execution. 
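+  // Illustrative example: with reset_interval = 10s and updates_per_interval = 5, intended_num_updates grows by one every two seconds, so costs are recorded at most every ~2s and a recompilation is triggered once five recordings have accumulated.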
+ if (intended_num_updates - cost_analysis_data_.num_cost_updates >= 1) { + cost_analysis_data_.is_available = false; + *do_recompilation = 1 + cost_analysis_data_.num_cost_updates >= + options.updates_per_interval; + return cost_analysis_data_.cost_recorder.get(); + } + return nullptr; } Status GraphExecutor::LoadedClientGraph::UpdateCost( @@ -964,13 +983,33 @@ Status GraphExecutor::LoadedClientGraph::UpdateCost( // add a test kernel that examines the cost. executable_context_ = std::move(new_executable_context); } - // Free the cost analysis data if it will not be used again. - cost_analysis_data_.tfrt_mlir = nullptr; - cost_analysis_data_.tf_mlir_with_op_keys = nullptr; - cost_analysis_data_.cost_recorder = nullptr; return OkStatus(); } +void GraphExecutor::LoadedClientGraph::UpdateCostAnalysisData( + absl::Time now, bool do_recompilation) { + tensorflow::mutex_lock lock(cost_analysis_data_.mu); + if (!do_recompilation) { + cost_analysis_data_.num_cost_updates += 1; + cost_analysis_data_.is_available = true; + return; + } + if (graph_executor_->options().cost_analysis_options.version == + Options::CostAnalysisOptions::kOnce) { + // Free the cost analysis data if it will not be used again. + cost_analysis_data_.is_available = false; + cost_analysis_data_.tfrt_mlir = nullptr; + cost_analysis_data_.tf_mlir_with_op_keys = nullptr; + cost_analysis_data_.cost_recorder = nullptr; + } else { + // Update cost analysis data. + cost_analysis_data_.cost_recorder = std::make_unique(); + cost_analysis_data_.is_available = true; + cost_analysis_data_.start_time = now; + cost_analysis_data_.num_cost_updates = 0; + } +} + tensorflow::Status GraphExecutor::CompileGraph( const std::string& graph_name, absl::Span input_tensor_names, diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index ad26d7e8770735..fd94bd002d98c1 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -22,8 +22,8 @@ limitations under the License. #include #include -#include "absl/base/call_once.h" #include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/core/platform/statusor.h" @@ -145,7 +145,11 @@ class GraphExecutor { executable_context_(std::move(executable_context)), stream_callback_id_(std::move(stream_callback_id)) { const auto& options = graph_executor_->options().cost_analysis_options; - if (options.version != Options::CostAnalysisOptions::DISABLED) { + if (options.version != Options::CostAnalysisOptions::kDisabled) { + // Initialize in a way that ensures recompilation on the first run. + cost_analysis_data_.start_time = absl::Now() - options.reset_interval; + cost_analysis_data_.is_available = true; + cost_analysis_data_.num_cost_updates = options.updates_per_interval - 1; cost_analysis_data_.cost_recorder = std::make_unique(); if (executable_context_->IsForMlrt()) { cost_analysis_data_.tf_mlir_with_op_keys = @@ -158,12 +162,16 @@ class GraphExecutor { // Returns this instance's CostRecorder if it is time to update costs, // else returns nullptr. Only allows one non-null return value at a time - // in order to provide thread-safety. - CostRecorder* MaybeGetCostRecorder(); + // in order to provide thread-safety. If do_recompilation becomes `true`, + // then recompiles using updated costs occurs. 
+ CostRecorder* MaybeGetCostRecorder(absl::Time now, bool* do_recompilation); // Updates the op cost values in this `LoadedClientGraph` with records from - // `cost_recorder` and initiates recompilation. + // `cost_recorder`. Status UpdateCost(const CostRecorder& cost_recorder, const Runtime& runtime); + // Updates `cost_analysis_data_` to make it accurate for the next execution. + // Assumes a cost update occurred this cycle. + void UpdateCostAnalysisData(absl::Time now, bool do_recompilation); // Getters. std::shared_ptr executable_context() const { tensorflow::mutex_lock lock(executable_context_mu_); @@ -192,12 +200,16 @@ class GraphExecutor { struct CostAnalysisData { mutable tensorflow::mutex mu; // Ensures only one GraphExecutor thread updates costs at a time. - bool is_available TF_GUARDED_BY(mu) = true; + bool is_available TF_GUARDED_BY(mu) = false; // Maintains the book-keeping of op costs. std::unique_ptr cost_recorder; // For recompilation in MLRT, TFRT respectively. mlir::OwningOpRef tf_mlir_with_op_keys; mlir::OwningOpRef tfrt_mlir; + // Start of current cost measurement cycle. + absl::Time start_time TF_GUARDED_BY(mu) = absl::Now(); + // Cost recordings within the current measurement cycle. + int num_cost_updates TF_GUARDED_BY(mu) = 0; }; CostAnalysisData cost_analysis_data_; @@ -207,7 +219,6 @@ class GraphExecutor { // Can be updated if online cost analysis is enabled. std::shared_ptr executable_context_ TF_GUARDED_BY(executable_context_mu_); - mutable absl::once_flag create_cost_recorder_once_; SyncResourceState sync_resource_state_; std::optional stream_callback_id_; @@ -345,6 +356,7 @@ class GraphExecutor { protected: // For testing basic Cost Analysis functionality. + absl::Duration simulated_duration_ = absl::ZeroDuration(); tensorflow::mutex num_recompilations_mu_; int num_recompilations_ TF_GUARDED_BY(num_recompilations_mu_) = 0; }; diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 103a3a00e36c38..1fe33f60d19a57 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "learning/brain/experimental/tfrt/native_lowering/kernels/sync_fallback_kernels.h" #include #include +#include "absl/time/time.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/core/framework/graph.pb.h" @@ -54,6 +55,10 @@ class GraphExecutorForTestingCostAnalysis : public GraphExecutor { tensorflow::mutex_lock lock(num_recompilations_mu_); return num_recompilations_; } + // This method is not thread safe. + void AdvanceTime(absl::Duration duration) { + simulated_duration_ = simulated_duration_ + duration; + } }; class GraphExecutorTest : public ::testing::TestWithParam {}; @@ -121,7 +126,7 @@ TEST_P(GraphExecutorTest, OnlineCostAnalysisOptionsOverrideToOnce) { // `enable_online_cost_analysis` = true. 
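+  // (GraphExecutor::Create overrides the version to kOnce when this legacy
+  // flag is set, so exactly one recompilation is expected below.)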
options.enable_online_cost_analysis = true; options.cost_analysis_options.version = - GraphExecutionOptions::CostAnalysisOptions::DISABLED; + GraphExecutionOptions::CostAnalysisOptions::kDisabled; options.enable_mlrt = GetParam(); TF_ASSERT_OK_AND_ASSIGN( @@ -166,6 +171,166 @@ TEST_P(GraphExecutorTest, OnlineCostAnalysisOptionsOverrideToOnce) { EXPECT_EQ(graph_executor->num_recompilations(), 1); } +TEST_P(GraphExecutorTest, OnlineCostAnalysisEveryTime) { + GraphDef graph_def; + TF_ASSERT_OK(GetSimpleGraphDef(graph_def)); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + GraphExecutor::Options options(runtime.get()); + options.cost_analysis_options.version = + GraphExecutionOptions::CostAnalysisOptions::kPeriodic; + options.cost_analysis_options.reset_interval = absl::ZeroDuration(); + options.cost_analysis_options.updates_per_interval = 1; + options.enable_mlrt = GetParam(); + + TF_ASSERT_OK_AND_ASSIGN( + auto fallback_state, + tensorflow::tfrt_stub::FallbackState::Create( + CreateDefaultSessionOptions(options), graph_def.library())); + auto resource_context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto graph_executor_base, + GraphExecutor::Create(std::move(options), *fallback_state, + std::move(resource_context), graph_def, + GetKernelRegistry())); + auto graph_executor = std::unique_ptr( + static_cast( + graph_executor_base.release())); + + // Set input 'x' to [[1, 1, 1]] + std::vector> inputs; + inputs.push_back({"input", CreateTfTensor( + /*shape=*/{1, 3}, /*data=*/{1, 1, 1})}); + + std::vector outputs; + + for (int i = 0; i < 10; ++i) { + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + EXPECT_THAT(GetTfTensorData(outputs[0]), + ::testing::ElementsAreArray({2})); + EXPECT_EQ(graph_executor->num_recompilations(), i + 1); + } +} + +TEST_P(GraphExecutorTest, OnlineCostAnalysisDisabled) { + GraphDef graph_def; + TF_ASSERT_OK(GetSimpleGraphDef(graph_def)); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + GraphExecutor::Options options(runtime.get()); + options.cost_analysis_options.version = + GraphExecutionOptions::CostAnalysisOptions::kDisabled; + options.cost_analysis_options.reset_interval = absl::ZeroDuration(); + options.cost_analysis_options.updates_per_interval = 1; + options.enable_mlrt = GetParam(); + + TF_ASSERT_OK_AND_ASSIGN( + auto fallback_state, + tensorflow::tfrt_stub::FallbackState::Create( + CreateDefaultSessionOptions(options), graph_def.library())); + auto resource_context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto graph_executor_base, + GraphExecutor::Create(std::move(options), *fallback_state, + std::move(resource_context), graph_def, + GetKernelRegistry())); + auto graph_executor = std::unique_ptr( + static_cast( + graph_executor_base.release())); + + // Set input 'x' to [[1, 1, 1]] + std::vector> inputs; + inputs.push_back({"input", CreateTfTensor( + /*shape=*/{1, 3}, /*data=*/{1, 1, 1})}); + + std::vector outputs; + + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 0); +} + +TEST_P(GraphExecutorTest, OnlineCostAnalysisPeriodic) { + GraphDef graph_def; + TF_ASSERT_OK(GetSimpleGraphDef(graph_def)); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + GraphExecutor::Options options(runtime.get()); + options.cost_analysis_options.version = + 
GraphExecutionOptions::CostAnalysisOptions::kPeriodic; + options.cost_analysis_options.reset_interval = absl::Minutes(10); + options.cost_analysis_options.updates_per_interval = 5; + options.enable_mlrt = GetParam(); + + TF_ASSERT_OK_AND_ASSIGN( + auto fallback_state, + tensorflow::tfrt_stub::FallbackState::Create( + CreateDefaultSessionOptions(options), graph_def.library())); + auto resource_context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + auto graph_executor_base, + GraphExecutor::Create(std::move(options), *fallback_state, + std::move(resource_context), graph_def, + GetKernelRegistry())); + auto graph_executor = std::unique_ptr( + static_cast( + graph_executor_base.release())); + + // Set input 'x' to [[1, 1, 1]] + std::vector> inputs; + inputs.push_back({"input", CreateTfTensor( + /*shape=*/{1, 3}, /*data=*/{1, 1, 1})}); + + std::vector outputs; + // First run always initiates a recompilation. + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 1); + + // We have specified that the costs should only update every + // `reset_interval` / `updates_per_interval` = 2 + // minutes. So no cost update occurs here. + for (int i = 0; i < 10; ++i) { + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 1); + } + // With 2 minute breaks in-between, 4 runs = 4 cost updates. + for (int i = 0; i < 4; ++i) { + graph_executor->AdvanceTime(absl::Minutes(2)); + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 1); + } + // A reset occurs on the 5th run. + graph_executor->AdvanceTime(absl::Minutes(2)); + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 2); + + // Demonstrate one more reset. 
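+  // Even with large time jumps, a reset still requires `updates_per_interval`
+  // recorded runs: the 4 runs below only record costs, and the 5th run
+  // triggers the recompilation.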
+ for (int i = 0; i < 4; ++i) { + graph_executor->AdvanceTime(absl::Minutes(1000)); + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 2); + } + graph_executor->AdvanceTime(absl::Minutes(1000)); + TF_ASSERT_OK(graph_executor->Run(/*run_options=*/{}, inputs, + /*output_tensor_names=*/{"rank"}, + /*target_tensor_names=*/{}, &outputs)); + EXPECT_EQ(graph_executor->num_recompilations(), 3); +} + REGISTER_OP("TestCancel") .Input("x: T") .Output("z: T") From ebae459c855a6e57063751ebf23dab28dab88044 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Mon, 24 Jul 2023 15:01:02 -0700 Subject: [PATCH 072/410] PR #3960: [ROCM] Support xla.gpu.cuda.graph.launch Imported from GitHub PR https://github.com/openxla/xla/pull/3960 Remove the cuda infix where possible Copybara import of the project: -- 098648c2cd3b12631f34170b45660c36ea83d6b2 by Dragan Mladjenovic : [ROCM] Support xla.gpu.cuda.graph.launch Remove the cuda infix where possible Merging this change closes #3960 PiperOrigin-RevId: 550683191 --- .../compiler/xla/debug_options_flags.cc | 54 ++-- .../gpu/transforms/add_concurrent_regions.cc | 4 +- .../transforms/add_hlo_trace_annotations.cc | 2 +- .../gpu/transforms/outline_cuda_graphs.cc | 36 +-- .../mlir/backends/gpu/transforms/passes.cc | 2 +- .../xla/mlir/backends/gpu/transforms/passes.h | 10 +- .../mlir/backends/gpu/transforms/passes.td | 8 +- .../tests/add_concurrent_regions.mlir | 44 +-- .../transforms/tests/outline_cuda_graphs.mlir | 84 +++--- .../xla/service/gpu/autotuner_compile_util.cc | 2 +- .../service/gpu/compile_module_to_llvm_ir.cc | 6 +- .../compiler/xla/service/gpu/runtime/BUILD | 10 +- .../xla/service/gpu/runtime/executable.cc | 33 ++- .../xla/service/gpu/runtime/executable.h | 6 +- .../xla/service/gpu/runtime/graph_launch.cc | 77 +++--- .../xla/service/gpu/runtime/graph_launch.h | 22 +- .../xla/service/gpu/runtime/kernel_launch.cc | 14 +- .../compiler/xla/stream_executor/cuda/BUILD | 11 - .../xla/stream_executor/cuda/cuda_graph.cc | 224 --------------- .../compiler/xla/stream_executor/gpu/BUILD | 19 ++ .../xla/stream_executor/gpu/gpu_graph.cc | 259 ++++++++++++++++++ .../{cuda/cuda_graph.h => gpu/gpu_graph.h} | 87 +++--- .../rocm/rocm_driver_wrapper.h | 10 + tensorflow/compiler/xla/xla.proto | 24 +- 24 files changed, 558 insertions(+), 490 deletions(-) delete mode 100644 tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc create mode 100644 tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.cc rename tensorflow/compiler/xla/stream_executor/{cuda/cuda_graph.h => gpu/gpu_graph.h} (53%) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 282d3e30e27e8c..576fbbbfbf2e8f 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -102,14 +102,14 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // flag. opts.set_xla_gpu_enable_cublaslt(false); - // TODO(b/258036887): Enable cuda_graph_level=2. Currently blocked by CUDA 12 + // TODO(b/258036887): Enable gpu_graph_level=2. Currently blocked by CUDA 12 // integration. 
- opts.set_xla_gpu_cuda_graph_level(1); - opts.set_xla_gpu_cuda_graph_num_runs_to_instantiate(-1); + opts.set_xla_gpu_graph_level(1); + opts.set_xla_gpu_graph_num_runs_to_instantiate(-1); opts.set_xla_gpu_enable_persistent_temp_buffers(false); - opts.set_xla_gpu_cuda_graph_min_graph_size(5); - opts.set_xla_gpu_cuda_graph_enable_concurrent_region(false); - opts.set_xla_gpu_cuda_graph_eviction_timeout_seconds(60); + opts.set_xla_gpu_graph_min_graph_size(5); + opts.set_xla_gpu_graph_enable_concurrent_region(false); + opts.set_xla_gpu_graph_eviction_timeout_seconds(60); // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. @@ -897,36 +897,36 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); flag_list->push_back(tsl::Flag( - "xla_gpu_cuda_graph_level", - int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_level), - debug_options->xla_gpu_cuda_graph_level(), - "Set CUDA graph level. 0 = off; 1 = capture fusions and memcpys; 2 = " + "xla_gpu_graph_level", + int32_setter_for(&DebugOptions::set_xla_gpu_graph_level), + debug_options->xla_gpu_graph_level(), + "Set GPU graph level. 0 = off; 1 = capture fusions and memcpys; 2 = " "capture convolutions and gemms; 3 = capture collectives.")); flag_list->push_back(tsl::Flag( - "xla_gpu_cuda_graph_num_runs_to_instantiate", + "xla_gpu_graph_num_runs_to_instantiate", int32_setter_for( - &DebugOptions::set_xla_gpu_cuda_graph_num_runs_to_instantiate), - debug_options->xla_gpu_cuda_graph_num_runs_to_instantiate(), - "Instantiate a cuda graph after the time a captured function is executed " + &DebugOptions::set_xla_gpu_graph_num_runs_to_instantiate), + debug_options->xla_gpu_graph_num_runs_to_instantiate(), + "Instantiate a gpu graph after the time a captured function is executed " "reaches the threshold.")); flag_list->push_back(tsl::Flag( - "xla_gpu_cuda_graph_min_graph_size", - int32_setter_for(&DebugOptions::set_xla_gpu_cuda_graph_min_graph_size), - debug_options->xla_gpu_cuda_graph_min_graph_size(), + "xla_gpu_graph_min_graph_size", + int32_setter_for(&DebugOptions::set_xla_gpu_graph_min_graph_size), + debug_options->xla_gpu_graph_min_graph_size(), "Capture a region as a function to be launched as cuda graph if the " "number of moved instructions reaches this threshold.")); + flag_list->push_back( + tsl::Flag("xla_gpu_graph_enable_concurrent_region", + bool_setter_for( + &DebugOptions::set_xla_gpu_graph_enable_concurrent_region), + debug_options->xla_gpu_graph_enable_concurrent_region(), + "Identify concurrent regions in gpu graphs and execute them " + "concurrently.")); flag_list->push_back(tsl::Flag( - "xla_gpu_cuda_graph_enable_concurrent_region", - bool_setter_for( - &DebugOptions::set_xla_gpu_cuda_graph_enable_concurrent_region), - debug_options->xla_gpu_cuda_graph_enable_concurrent_region(), - "Identify concurrent regions in cuda graphs and execute them " - "concurrently.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_cuda_graph_eviction_timeout_seconds", + "xla_gpu_graph_eviction_timeout_seconds", int32_setter_for( - &DebugOptions::set_xla_gpu_cuda_graph_eviction_timeout_seconds), - debug_options->xla_gpu_cuda_graph_eviction_timeout_seconds(), + &DebugOptions::set_xla_gpu_graph_eviction_timeout_seconds), + debug_options->xla_gpu_graph_eviction_timeout_seconds(), "Timeout in seconds to evict instantiated Gpu graphs from device. 
When " "XLA instantiates new Gpu graphs, it evicts graphs that were not " "recently executed to free space on device.")); diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc index 2a8457cf1630df..6c2b840181795e 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc @@ -209,9 +209,9 @@ void AddConcurrentRegionsPass::runOnOperation() { auto func_ops = llvm::to_vector(module.getOps()); for (auto func_op : func_ops) { - // Find the cuda graph capture function. + // Find the gpu graph capture function. if (absl::StrContains(func_op.getSymNameAttr().str(), - "xla.gpu.cuda.graph.capture")) { + "xla.gpu.graph.capture")) { InsertConcurrentRegions(func_op, custom_calls, getAnalysis()); } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc index 3d41826265d9e9..76c2b9941bd580 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc @@ -64,7 +64,7 @@ void AddHloTraceAnnotationsPass::runOnOperation() { // TODO(b/275240695): Report the graph content once the Xprof team provides // an API. if (absl::StrContains(call.getCalleeAttr().getValue(), - "xla.gpu.cuda.graph.launch")) { + "xla.gpu.graph.launch")) { auto capture = call->getAttr("capture").cast(); std::string op_name = "cuda_graph/" + capture.getValue().str(); auto annotation = HloTraceAttr::get(ctx, std::move(op_name)); diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc index 3a921bb4af7b5d..bbca585ef50bc3 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -46,19 +46,19 @@ limitations under the License. namespace xla { namespace gpu { -#define GEN_PASS_DEF_OUTLINECUDAGRAPHSPASS +#define GEN_PASS_DEF_OUTLINEGPUGRAPHSPASS #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT using mlir::gpu::LaunchFuncOp; -class OutlineCudaGraphsPass - : public impl::OutlineCudaGraphsPassBase { +class OutlineGpuGraphsPass + : public impl::OutlineGpuGraphsPassBase { public: - OutlineCudaGraphsPass() = default; - explicit OutlineCudaGraphsPass(int cuda_graph_level, int min_graph_size) - : cuda_graph_level_(cuda_graph_level) { + OutlineGpuGraphsPass() = default; + explicit OutlineGpuGraphsPass(int gpu_graph_level, int min_graph_size) + : gpu_graph_level_(gpu_graph_level) { this->min_graph_size_ = min_graph_size; } @@ -69,7 +69,7 @@ class OutlineCudaGraphsPass } private: - int cuda_graph_level_ = 3; + int gpu_graph_level_ = 3; }; //===----------------------------------------------------------------------===// @@ -350,7 +350,7 @@ static LogicalResult Outline(unsigned ordinal, // Create a function in the compiled module. 
auto func = b.create( - "xla.gpu.cuda.graph.capture", + "xla.gpu.graph.capture", FunctionType::get(ctx, TypeRange(ValueRange(args)), TypeRange())); Operation* first_op = seq.front().first; @@ -401,7 +401,7 @@ static LogicalResult Outline(unsigned ordinal, // Create a custom call declaration corresponding to the outlined graph // capture function. func::FuncOp graph_launch = custom_calls.GetOrCreate( - b, "xla.gpu.cuda.graph.launch", TypeRange(ValueRange(args)), TypeRange()); + b, "xla.gpu.graph.launch", TypeRange(ValueRange(args)), TypeRange()); // Call the cuda graph launch custom call right before the first moved op. auto insertion_point = llvm::find_if(seq, [](auto capture) { @@ -451,13 +451,13 @@ static LogicalResult Outline(unsigned ordinal, //===----------------------------------------------------------------------===// -void OutlineCudaGraphsPass::runOnOperation() { +void OutlineGpuGraphsPass::runOnOperation() { SymbolTable sym_table(getOperation()); CustomCallDeclarations custom_calls(std::move(sym_table)); OpCapturePatternSet patterns; - if (cuda_graph_level_ >= 1) { + if (gpu_graph_level_ >= 1) { // Enable capturing fusions and memcpies. patterns.emplace_back(new LaunchFuncOpCapture()); patterns.emplace_back(new ConstantOpCapture()); @@ -466,7 +466,7 @@ void OutlineCudaGraphsPass::runOnOperation() { patterns.emplace_back(new ReinterpretCastOpCapture()); } - if (cuda_graph_level_ >= 2) { + if (gpu_graph_level_ >= 2) { // Enable capturing conv/gemms. patterns.emplace_back(new ConvForwardOpCapture()); patterns.emplace_back(new ConvBackwardInputOpCapture()); @@ -484,14 +484,14 @@ void OutlineCudaGraphsPass::runOnOperation() { } } -std::unique_ptr> createOutlineCudaGraphsPass() { - return std::make_unique(); +std::unique_ptr> createOutlineGpuGraphsPass() { + return std::make_unique(); } -std::unique_ptr> createOutlineCudaGraphsPass( - int cuda_graph_level, int min_graph_size) { - return std::make_unique(cuda_graph_level, - min_graph_size); +std::unique_ptr> createOutlineGpuGraphsPass( + int gpu_graph_level, int min_graph_size) { + return std::make_unique(gpu_graph_level, + min_graph_size); } } // namespace gpu diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc index 54f91186b8be8f..c943c33a550c78 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc @@ -38,7 +38,7 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, // Outline CUDA-Graph-compatible operations into graph capture functions. 
pm.addPass( - createOutlineCudaGraphsPass(opts.cuda_graph_level, opts.min_graph_size)); + createOutlineGpuGraphsPass(opts.gpu_graph_level, opts.min_graph_size)); if (opts.enable_concurrent_region) { pm.addPass(createAddConcurrentRegionsPass()); } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h index 7d3ea262443770..2c13ee239b8899 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h @@ -32,7 +32,7 @@ namespace gpu { #define GEN_PASS_DECL_CONVERTLMHLOTOGPULAUNCHPASS #define GEN_PASS_DECL_CONVERTLMHLOTOGPURUNTIMEPASS #define GEN_PASS_DECL_CONVERTMEMREFGETGLOBALTOARGPASS -#define GEN_PASS_DECL_OUTLINECUDAGRAPHSPASS +#define GEN_PASS_DECL_OUTLINEGPUGRAPHSPASS #define GEN_PASS_DECL_ADDCONCURRENTREGIONSPASS #define GEN_PASS_DECL_STREAMASSIGNMENTPASS #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" @@ -43,7 +43,7 @@ struct GpuPipelineOpts { // Enable experimental pass that outlines parts of the XLA computation into // CUDA Graphs, which allows us to amortize the cost of launching multiple // device kernels. - int32_t cuda_graph_level = 0; + int32_t gpu_graph_level = 0; int32_t min_graph_size = 0; bool enable_concurrent_region = false; }; @@ -101,10 +101,10 @@ createAddHloTraceAnnotationsPass(); //===----------------------------------------------------------------------===// std::unique_ptr> -createOutlineCudaGraphsPass(); +createOutlineGpuGraphsPass(); -std::unique_ptr> -createOutlineCudaGraphsPass(int32_t cuda_graph_level, int32_t min_graph_size); +std::unique_ptr> createOutlineGpuGraphsPass( + int32_t gpu_graph_level, int32_t min_graph_size); //===----------------------------------------------------------------------===// // Passes for marking concurrent region in CUDA graph capture function. diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 29cf6991e26431..79b87695eeb4b3 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -152,8 +152,8 @@ def AddHloTraceAnnotationsPass : // Xla Gpu <-> Cuda Graphs integration. //===----------------------------------------------------------------------===// -def OutlineCudaGraphsPass : - Pass<"xla-gpu-outline-cuda-graphs", "mlir::ModuleOp"> { +def OutlineGpuGraphsPass : + Pass<"xla-gpu-outline-gpu-graphs", "mlir::ModuleOp"> { let summary = "Outline sequences of Xla Gpu operations into CUDA Graphs"; let description = [{ @@ -178,13 +178,13 @@ def OutlineCudaGraphsPass : } // Replace a sequence of graph launch operations with a call to runtime API. 
- call @xla.gpu.cuda.graph.launch(%arg0: memref, + call @xla.gpu.graph.launch(%arg0: memref, %arg1: memref) attributes { capture = @capture } ``` }]; - let constructor = "createOutlineCudaGraphsPass()"; + let constructor = "createOutlineGpuGraphsPass()"; let options = [ Option<"min_graph_size_", "min_graph_size", "int64_t", /*default=*/"2", diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir index 97fd88c47fb3c6..e9f21a15b65e72 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_concurrent_regions.mlir @@ -14,8 +14,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> @@ -46,8 +46,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> @@ -76,8 +76,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> @@ -106,8 +106,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { + // CHECK: func 
@xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %view = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi1> @@ -136,8 +136,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<144xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<144xi8>) { %c0 = arith.constant 0 : index %c72 = arith.constant 72 : index %c1 = arith.constant 1 : index @@ -168,8 +168,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<144xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<144xi8>) { %c0 = arith.constant 0 : index %c36 = arith.constant 36 : index %c1 = arith.constant 1 : index @@ -198,8 +198,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<144xi8> {lmhlo.constant_name = "cst0"}) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<144xi8> {lmhlo.constant_name = "cst0"}) { %c0 = arith.constant 0 : index %c36 = arith.constant 36 : index %c1 = arith.constant 1 : index @@ -225,8 +225,8 @@ module attributes {gpu.container_module} { module attributes {gpu.container_module} { - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<16xi8>, + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, %arg1: memref<16xi8>, %arg2: memref<16xi8>, %arg3: memref<16xi8>) { @@ -261,8 +261,8 @@ module attributes {gpu.container_module} { gpu.func @fn0(%arg0: memref<16xi8> {lmhlo.written} ) kernel { gpu.return } } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<16xi8>, + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, %arg1: memref<16xi8>, %arg2: memref<16xi8>) { %c0 = arith.constant 0 : index @@ -290,8 +290,8 @@ module attributes {gpu.container_module} { module attributes {gpu.container_module} { - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<16xi8>, + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<16xi8>, %arg1: memref<16xi8>, %arg2: memref<16xi8>) { %c0 = arith.constant 0 : index @@ -325,8 +325,8 @@ module attributes {gpu.container_module} { } - // CHECK: func @xla.gpu.cuda.graph.capture - func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { + // CHECK: func @xla.gpu.graph.capture + func.func @xla.gpu.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<328xi8>, %arg3: memref<72xi8>, %arg4: memref<72xi8>, %arg5: memref<72xi8>, %arg6: memref<72xi8>, %arg7: memref<72xi8>, %arg8: memref<72xi8>, %arg9: memref<72xi8>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %view = memref.view %arg0[%c0][] : memref<72xi8> to 
memref<3x3xi64> diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir index ce2a7227b97a89..5c151ffbd8b36f 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir @@ -1,4 +1,4 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-outline-cuda-graphs \ +// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-outline-gpu-graphs \ // RUN: | FileCheck %s module attributes {gpu.container_module} { @@ -24,8 +24,8 @@ func.func @func(%arg0: memref, %arg1: memref) { %c5 = arith.constant 5 : index %c6 = arith.constant 6 : index - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 @@ -41,7 +41,7 @@ func.func @func(%arg0: memref, %arg1: memref) { func.return } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 // CHECK-NEXT: %[[C3:.*]] = arith.constant 3 @@ -56,8 +56,8 @@ func.func @func(%arg0: memref, %arg1: memref) { // CHECK-SAME: threads in (%[[C6]], %[[C5]], %[[C4]]) // CHECK-NEXT: return -// CHECK: func private @xla.gpu.cuda.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cuda.graph.launch"} +// CHECK: func private @xla.gpu.graph.launch(memref, memref) +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} } // ----- @@ -76,7 +76,7 @@ func.func @func(%arg0: memref) { %c1 = arith.constant 1 : index // CHECK: gpu.launch_func {{.*}} args(%[[ARG0]] : memref) - // CHECK-NOT: call @xla.gpu.cuda.graph.launch + // CHECK-NOT: call @xla.gpu.graph.launch gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) @@ -107,7 +107,7 @@ func.func @func(%arg0: memref) { // CHECK: %[[C1:.*]] = arith.constant 1 %c1 = arith.constant 1 : index - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) // CHECK-SAME: {capture = @[[CAPTURE:.*]]} gpu.launch_func @gpu_module::@fn0 @@ -127,7 +127,7 @@ func.func @func(%arg0: memref) { // CHECK: call @external call @external(): () -> () - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) // CHECK-SAME: {capture = @[[CAPTURE_0:.*]]} gpu.launch_func @gpu_module::@fn1 @@ -181,8 +181,8 @@ gpu.module @gpu_module attributes {binary = "kernel binary"} { func.func @func(%arg0: memref, %arg1: memref) { cf.br ^bb2 ^bb1: - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 @@ -203,7 +203,7 @@ func.func @func(%arg0: memref, %arg1: memref) { } } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 @@ -227,8 +227,8 @@ func.func @func(%arg0: memref<16xi8>) { call @external() : () -> () - // CHECK: call 
@xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) @@ -241,7 +241,7 @@ func.func @func(%arg0: memref<16xi8>) { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 0 // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: memref.view @@ -267,8 +267,8 @@ func.func @func(%arg0: memref<16xi8>) { call @external() : () -> () - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: memref.view // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) @@ -283,7 +283,7 @@ func.func @func(%arg0: memref<16xi8>) { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 @@ -313,8 +313,8 @@ module attributes {gpu.container_module} { %c2 = arith.constant 0 : index %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%raw_arg0 : memref<16xi8>) @@ -324,7 +324,7 @@ module attributes {gpu.container_module} { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 0 // CHECK-NEXT: memref.view // CHECK-NEXT: arith.constant 0 @@ -361,7 +361,7 @@ module attributes {gpu.container_module} { %arg2 = memref.view %raw_arg2[%c2][] : memref<16xi8> to memref<2x2xf32> - // CHECK-NOT: call @xla.gpu.cuda.graph.launch + // CHECK-NOT: call @xla.gpu.graph.launch // CHECK: "lmhlo_gpu.gemm" "lmhlo_gpu.gemm"(%arg0, %arg1, %arg2) {algorithm = -5, alpha_imag = 0.000000e+00 : f64, alpha_real = 1.000000e+00 : f64, beta = 0.000000e+00 : f64, batch_size = 1 : i64, lhs_stride = 4 : i64, rhs_stride = 4 : i64, dot_dimension_numbers = #mhlo.dot} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) @@ -468,8 +468,8 @@ module attributes {gpu.container_module} { ) { %c0 = arith.constant 0 : index - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} lmhlo_gpu.conv_forward(%input, %filter, %output, %scratch) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, 
f], window = { stride = [1, 1], @@ -504,7 +504,7 @@ module attributes {gpu.container_module} { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 0 // CHECK-NEXT: lmhlo_gpu.conv_forward // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 @@ -521,8 +521,8 @@ module attributes {gpu.container_module} { %dst = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> %src = memref.view %arg0[%c0][] : memref<100xi8> to memref<10xf32> - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> gpu.memcpy %dst, %src : memref<10xf32>, memref<10xf32> @@ -532,7 +532,7 @@ module attributes {gpu.container_module} { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK: gpu.memcpy // CHECK: gpu.memcpy // CHECK-NEXT: return @@ -555,8 +555,8 @@ func.func @func(%arg0: memref<16xi8>) { call @external() : () -> () - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%view : memref<16xi8, strided<[1], offset: 0>>) @@ -569,7 +569,7 @@ func.func @func(%arg0: memref<16xi8>) { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: memref.reinterpret_cast // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 @@ -593,15 +593,15 @@ func.func @func(%arg0: memref<16xi8>, %cond: memref) { call @external() : () -> () "lmhlo.while"(%cond) ({ - // CHECK: func.call @xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: func.call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) "lmhlo.terminator"() : () -> () }, { - // CHECK: func.call @xla.gpu.cuda.graph.launch(%[[ARG0]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture_0} + // CHECK: func.call @xla.gpu.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture_0} gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%arg0: memref<16xi8>) gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) @@ -614,13 +614,13 @@ func.func @func(%arg0: memref<16xi8>, %cond: memref) { func.func private @external() } -// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK: func @xla.gpu.graph.capture // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 // CHECK-NEXT: return -// CHECK: func @xla.gpu.cuda.graph.capture_0 +// CHECK: func @xla.gpu.graph.capture_0 // CHECK-NEXT: arith.constant 1 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 // CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 @@ -647,8 +647,8 @@ func.func @func(%arg0: memref {lmhlo.constant_name = "cst0"}, %arg1: memref {lmhlo.constant_name = 
"cst1"}) { %c1 = arith.constant 1 : index - // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK: call @xla.gpu.graph.launch(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: {capture = @xla.gpu.graph.capture} // CHECK-NEXT: return gpu.launch_func @gpu_module::@fn0 @@ -664,7 +664,7 @@ func.func @func(%arg0: memref {lmhlo.constant_name = "cst0"}, func.return } -// CHECK: func @xla.gpu.cuda.graph.capture( +// CHECK: func @xla.gpu.graph.capture( // CHECK-SAME: %[[ARG0]]: memref {lmhlo.constant_name = "cst0"}, // CHECK-SAME: %[[ARG1]]: memref {lmhlo.constant_name = "cst1"}) // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 @@ -676,6 +676,6 @@ func.func @func(%arg0: memref {lmhlo.constant_name = "cst0"}, // CHECK-SAME: threads in (%[[C1]], %[[C1]], %[[C1]]) // CHECK-NEXT: return -// CHECK: func private @xla.gpu.cuda.graph.launch(memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cuda.graph.launch"} +// CHECK: func private @xla.gpu.graph.launch(memref, memref) +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.graph.launch"} } diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 46288839d6567a..6bd4e7a6c64782 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -120,7 +120,7 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, // Avoid using another thread pool. opts_.set_xla_gpu_force_compilation_parallelism(1); // Avoid using GPU graphs as we don't want to measure graph construction time. - opts_.set_xla_gpu_cuda_graph_level(0); + opts_.set_xla_gpu_graph_level(0); // Disable experimental OpenXLA runtime. 
opts_.set_xla_gpu_enable_openxla_runtime(false); } diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index fd9081b3b716c1..07347af3025796 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -117,10 +117,10 @@ static Status LowerToXlaGpuRuntime(mlir::ModuleOp module, mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit); GpuPipelineOpts opts; - opts.cuda_graph_level = debug_options.xla_gpu_cuda_graph_level(); - opts.min_graph_size = debug_options.xla_gpu_cuda_graph_min_graph_size(); + opts.gpu_graph_level = debug_options.xla_gpu_graph_level(); + opts.min_graph_size = debug_options.xla_gpu_graph_min_graph_size(); opts.enable_concurrent_region = - debug_options.xla_gpu_cuda_graph_enable_concurrent_region(); + debug_options.xla_gpu_graph_enable_concurrent_region(); populateXlaGpuRuntimePasses(pm, thunk_sequence, opts); if (pm.run(module).failed()) { diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 98cb3764cba885..9076401a67147d 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -350,15 +350,14 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service/gpu:non_atomically_upgradeable_rw_lock", "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_graph", "//tensorflow/tsl/profiler/lib:scoped_annotation_stack", "//tensorflow/tsl/profiler/lib:traceme", "//tensorflow/tsl/profiler/lib:traceme_encode", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/stream_executor/cuda:cuda_graph", - ]), + ], ) cc_library( @@ -407,11 +406,10 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:launch_dimensions", "//tensorflow/compiler/xla/service/gpu:stream_executor_util", "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_graph", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/synchronization", - ] + if_cuda_is_configured([ - "//tensorflow/compiler/xla/stream_executor/cuda:cuda_graph", - ]), + ], ) cc_library( diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index da1a9d78a19196..d70f834c0ed94c 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -97,13 +97,13 @@ void RegisterXlaGpuRuntimeCustomCalls(DirectCustomCallRegistry& registry) { RegisterTopkCustomCall(registry); #if GOOGLE_CUDA + RegisterMatmulCustomCalls(registry); +#endif // GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Graph launch kernels depend on Cuda Graph API. 
RegisterGraphLaunchCustomCalls(registry); RegisterConcurrentRegionCustomCalls(registry); - RegisterMatmulCustomCalls(registry); -#endif // GOOGLE_CUDA -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM RegisterXlaClassicCustomCalls(registry); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } @@ -151,9 +151,9 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(jit_executable)), debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -168,9 +168,9 @@ GpuRuntimeExecutable::GpuRuntimeExecutable( buffer_sizes_(std::move(buffer_sizes)), executable_(std::move(aot_executable)), debug_options_(std::move(debug_options)), -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM graph_instances_(module_name_, GetNumGraphs(executable())), -#endif // GOOGL_CUDA +#endif // GOOGL_CUDA || TENSORFLOW_USE_ROCM modules_state_(std::move(modules_state)), ffi_modules_state_(std::move(ffi_modules_state)) { ExportModules(dynamic_custom_calls_); // export runtime modules @@ -385,7 +385,7 @@ Status GpuRuntimeExecutable::Execute( StreamExecutorConvRunners::Snapshot conv_runners = conv_runners_(executor)->snapshot(); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::shared_ptr executor_graphs = graph_instances_(executor); @@ -393,7 +393,7 @@ Status GpuRuntimeExecutable::Execute( executor_graphs->snapshot(); CapturedFunctionExecutionCount::Snapshot execution_count = captured_function_counts_(executor)->snapshot(); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Kernels in concurrent regions should be launched on borrowed stream, so // that the cuda graph won't record dependencies between kernels. @@ -420,8 +420,11 @@ Status GpuRuntimeExecutable::Execute( &collectives_, &fft_plans, &send_recv_events, &gpu_lock, #if GOOGLE_CUDA // Auxiliary data that is available only if compiled with CUDA support. - &matmul_plans, &graph_instances, &execution_count, + &matmul_plans, #endif // GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + &graph_instances, &execution_count, +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM &concurrent_region_status, // Null pointer will be interpreted as an absence of async collectives // support and custom calls will safely return an error. @@ -433,9 +436,9 @@ Status GpuRuntimeExecutable::Execute( return InternalError("Failed to initialize runtime modules state: %s", state_ref.status().message()); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Instantiate all CUDA graphs before executing the main function. - if (debug_options_.xla_gpu_cuda_graph_num_runs_to_instantiate() < 0 && + if (debug_options_.xla_gpu_graph_num_runs_to_instantiate() < 0 && !graph_instances_.InstantiatedAllGraphs(run_options, executable)) { // To instantiate all Gpu graphs we have to pass a valid device pointer // because some device operations in XLA (e.g. 
memcpy) query device @@ -455,13 +458,13 @@ Status GpuRuntimeExecutable::Execute( if (auto instantiated = graph_instances_.InstantiateAllGraphs( run_options, executable, user_data, device_ptr, - debug_options_.xla_gpu_cuda_graph_eviction_timeout_seconds()); + debug_options_.xla_gpu_graph_eviction_timeout_seconds()); !instantiated.ok()) { - return InternalError("Failed to instantiate CUDA graphs: %s", + return InternalError("Failed to instantiate GPU graphs: %s", instantiated.message()); } } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Collect all emitted diagnostic messages. std::string diagnostic; diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.h b/tensorflow/compiler/xla/service/gpu/runtime/executable.h index 0405c86db315b6..84189c0ead3cc8 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.h @@ -164,11 +164,13 @@ class GpuRuntimeExecutable { #if GOOGLE_CUDA // Keep matmul execution plans (only if cuBLASLt is available). MatmulPlans cublas_lt_matmul_plans_; +#endif // GOOGLE_CUDA - // Keep captured and instantiated CUDA graphs instances. +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + // Keep captured and instantiated GPU graphs instances. GraphInstances graph_instances_; CapturedFunctionExecutionCounts captured_function_counts_; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Keep an executable state for all registered runtime modules. ModulesState modules_state_; diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index da78f21227822e..de914f695d726a 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -40,9 +40,9 @@ limitations under the License. #include "tensorflow/tsl/profiler/lib/traceme.h" #include "tensorflow/tsl/profiler/lib/traceme_encode.h" -#if GOOGLE_CUDA -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" -#endif // #if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h" +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { @@ -60,18 +60,18 @@ using xla::runtime::MemrefDesc; using xla::runtime::MemrefType; using xla::runtime::StridedMemrefView; -#if GOOGLE_CUDA -using se::gpu::OwnedCudaGraph; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +using se::gpu::OwnedGpuGraph; // Captures Gpu graph by running given function in capture mode. -static absl::StatusOr CaptureGraph( +static absl::StatusOr CaptureGraph( const ServiceExecutableRunOptions* run_options, runtime::FunctionRef function_ref, Arguments& args, CustomCall::UserData user_data); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM //===----------------------------------------------------------------------===// -// CUDA graphs caching. +// GPU graphs caching. //===----------------------------------------------------------------------===// struct GraphInstances::Impl { @@ -273,7 +273,7 @@ Status GraphInstances::InstantiateAllGraphs( // All Gpu graphs are already instantiated for a given executor. if (state.instantiated) return OkStatus(); - TraceMe trace("cuda.graph.instantiate_all"); + TraceMe trace("gpu.graph.instantiate_all"); // Evict all timeout graphs before trying to instantiate new ones. 
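+  // Eviction is governed by xla_gpu_graph_eviction_timeout_seconds: graphs
+  // that have not been executed recently are destroyed to free space on the
+  // device before new ones are instantiated.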
EvictAllGraphs(executor, eviction_timeout_seconds); @@ -290,7 +290,7 @@ Status GraphInstances::InstantiateAllGraphs( // with correct pointers. for (unsigned ordinal = 1; ordinal < executable.num_functions(); ++ordinal) { if (!absl::StartsWith(executable.function_name(ordinal), - "xla.gpu.cuda.graph.capture")) + "xla.gpu.graph.capture")) continue; VLOG(3) << "Instantiate Gpu graph defined by capture function @" @@ -298,7 +298,7 @@ Status GraphInstances::InstantiateAllGraphs( << ")"; TraceMe trace_instantiation([&] { - return TraceMeEncode("cuda.graph.instantiate", {{"ordinal", ordinal}}); + return TraceMeEncode("gpu.graph.instantiate", {{"ordinal", ordinal}}); }); FunctionRef function_ref = executable.function_ref(ordinal); @@ -327,12 +327,12 @@ Status GraphInstances::InstantiateAllGraphs( /*offset=*/0, sizes, strides); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Instantiate a Gpu graph with fake arguments. auto instantiate = [&]() -> absl::StatusOr { TF_ASSIGN_OR_RETURN( auto g, CaptureGraph(run_options, function_ref, args, user_data)); - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); return GraphInstance(0, std::move(e)); }; @@ -349,7 +349,7 @@ Status GraphInstances::InstantiateAllGraphs( // Otherwise return an error to the caller. if (!instance.ok()) return instance.status(); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } state.instantiated = true; @@ -384,10 +384,10 @@ H AbslHashValue(H h, const RemainingArgsPtrs& m) { } //----------------------------------------------------------------------------// -// Runs capture function exported by the executable to constuct a CUDA graph. +// Runs capture function exported by the executable to construct a gpu graph. //----------------------------------------------------------------------------// -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM static bool InDebugMode() { #ifdef NDEBUG @@ -413,7 +413,7 @@ static absl::Status ForwardArguments(CustomCall::RemainingArgs fwd_args, return OkStatus(); } -static absl::StatusOr CaptureGraph( +static absl::StatusOr CaptureGraph( const ServiceExecutableRunOptions* run_options, runtime::FunctionRef function_ref, Arguments& args, CustomCall::UserData user_data) { @@ -423,7 +423,7 @@ static absl::StatusOr CaptureGraph( se::StreamExecutor* executor = run_options->stream()->parent(); // Initialize (with memoization) BlasSupport here because cublasCreate fails - // during cuda graph capturing. + // during gpu graph capturing. if (function_ref.RequiresBlas()) { if (!executor->AsBlas()) { return absl::InternalError("Failed to initialize BLAS support"); @@ -439,7 +439,7 @@ static absl::StatusOr CaptureGraph( capture_stream.status().message())); TraceMe trace([&] { - return TraceMeEncode("cuda.graph.capture", + return TraceMeEncode("gpu.graph.capture", {{"ordinal", function_ref.ordinal()}}); }); @@ -468,14 +468,14 @@ static absl::StatusOr CaptureGraph( opts.async_task_runner = reinterpret_cast(0XDEADBEEF); // Create a graph from running the graph capture function. 
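+  // During stream capture the captured function's work is recorded into a
+  // graph rather than executed; if the capture function fails, the error is
+  // returned together with any collected diagnostic text (see the error path
+  // below).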
- auto captured = se::gpu::CaptureCudaGraph(capture_stream->get(), [&]() { + auto captured = se::gpu::CaptureGpuGraph(capture_stream->get(), [&]() { return function_ref(args, runtime::NoResultConverter{}, opts, /*verify_arguments=*/InDebugMode()) .status(); }); if (!captured.ok()) { - return InternalError("CaptureCudaGraph failed (%s): %s", + return InternalError("CaptureGpuGraph failed (%s): %s", diagnostic.empty() ? "" : diagnostic, captured.status().ToString()); } @@ -491,7 +491,7 @@ static absl::Status RunGraphWithoutCapture( opts.custom_call_data = &user_data; TraceMe trace([&] { - return TraceMeEncode("cuda.graph.run_no_capture", + return TraceMeEncode("gpu.graph.run_no_capture", {{"ordinal", function_ref.ordinal()}}); }); @@ -518,10 +518,10 @@ static absl::Status RunGraphWithoutCapture( return absl::OkStatus(); } -#endif // #if GOOGLE_CUDA +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM //===----------------------------------------------------------------------===// -// Define the cuda graph launch custom call. +// Define the gpu graph launch custom call. //===----------------------------------------------------------------------===// static absl::Status LaunchGraph( @@ -536,10 +536,10 @@ static absl::Status LaunchGraph( NonAtomicallyUpgradeableRWLock* gpu_lock, ConcurrentRegionStatus* region_status, CustomCall::RemainingArgs fwd_args, CustomCall::FunctionOrdinal capture) { -#if GOOGLE_CUDA - VLOG(1) << "Launch Cuda Graph: ordinal = " << capture.ordinal; +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + VLOG(1) << "Launch GPU Graph: ordinal = " << capture.ordinal; - // Get a reference to exported function that captures the cuda graph. + // Get a reference to exported function that captures the gpu graph. runtime::FunctionRef function_ref = executable->function_ref(capture.ordinal); // Compute the hash of the buffer arguments. @@ -559,11 +559,11 @@ static absl::Status LaunchGraph( int64_t count = (*get_count)->fetch_add(1); int64_t num_runs_to_instantiate = - debug_options->xla_gpu_cuda_graph_num_runs_to_instantiate(); + debug_options->xla_gpu_graph_num_runs_to_instantiate(); // TODO(ezhulenev): Cupti tracing leads to deadlocks in CUDA 11. Always fall // back on regular execution if we detect tracing activity. -#if CUDA_VERSION >= 12000 +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 bool is_profiling = false; #else bool is_profiling = tsl::profiler::ScopedAnnotationStack::IsEnabled(); @@ -583,7 +583,7 @@ static absl::Status LaunchGraph( TF_ASSIGN_OR_RETURN( auto g, CaptureGraph(run_options, function_ref, args, user_data())); - TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateCudaGraph(std::move(g))); + TF_ASSIGN_OR_RETURN(auto e, se::gpu::InstantiateGpuGraph(std::move(g))); return GraphInstance(ptrs_hash, std::move(e)); }; @@ -599,7 +599,7 @@ static absl::Status LaunchGraph( // If pointers did not change we can run captured graph. if (ptrs_hash == instance->ptr_hash) { TraceMe trace([&] { - return TraceMeEncode("cuda.graph.launch_cached", + return TraceMeEncode("gpu.graph.launch_cached", {{"ordinal", capture.ordinal}}); }); @@ -614,7 +614,7 @@ static absl::Status LaunchGraph( Arguments args(fwd_args.size()); TF_RETURN_IF_ERROR(ForwardArguments(fwd_args, args)); - // Capture CUDA graph by running capture function. + // Capture GPU graph by running capture function. 
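+  // The argument pointers no longer match the cached instance's hash, so the
+  // region is re-captured with the new buffers and the cached executable
+  // graph is refreshed before the launch below (traced as
+  // "gpu.graph.launch_updated").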
TF_ASSIGN_OR_RETURN( auto g, CaptureGraph(run_options, function_ref, args, user_data())); @@ -629,24 +629,23 @@ static absl::Status LaunchGraph( instance->ptr_hash = ptrs_hash; TraceMe trace([&] { - return TraceMeEncode("cuda.graph.launch_updated", + return TraceMeEncode("gpu.graph.launch_updated", {{"ordinal", capture.ordinal}}); }); return instance->exec.Launch(run_options->stream()); +#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM -#else // #if !GOOGLE_CUDA + return absl::InternalError("GPU graphs are not supported"); - return absl::InternalError("Cuda graphs are not supported"); - -#endif // #if GOOGLE_CUDA +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM } //===----------------------------------------------------------------------===// XLA_RUNTIME_DEFINE_CUSTOM_CALL( Launch, FunctionWrapper(), checks, - CustomCall::Bind("xla.gpu.cuda.graph.launch") + CustomCall::Bind("xla.gpu.graph.launch") .UserData() .UserData() .UserData() @@ -665,7 +664,7 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL( void RegisterGraphLaunchCustomCalls( runtime::DirectCustomCallRegistry& registry) { - registry.Register("xla.gpu.cuda.graph.launch", Launch); + registry.Register("xla.gpu.graph.launch", Launch); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h index c94e9b63f7a9b3..1cf075314ae83d 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.h @@ -29,9 +29,9 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" -#if GOOGLE_CUDA -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" -#endif // #if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h" +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { @@ -48,37 +48,37 @@ class StreamExecutorGraphInstances; // Forward declare class CapturedFunctionExecutionCount : public runtime::StateVector>> {}; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -// A state vector that owns all instantiated CUDA graphs. Graph capture function +// A state vector that owns all instantiated GPU graphs. Graph capture function // ordinal is the key in this container. class StreamExecutorGraphInstances : public runtime::StateVector {}; -// Instantiated CUDA graph instance guarded with a mutex for exclusive access. +// Instantiated GPU graph instance guarded with a mutex for exclusive access. struct GraphInstance { - GraphInstance(size_t ptr_hash, se::gpu::OwnedCudaGraphExec exec) + GraphInstance(size_t ptr_hash, se::gpu::OwnedGpuGraphExec exec) : ptr_hash(ptr_hash), exec(std::move(exec)), mutex(new absl::Mutex) {} // Graph instance is fully identified by the hash of its pointer arguments // because currently it's guaranteed that all shapes and launch dimensions // will be constant from run to run. size_t ptr_hash ABSL_GUARDED_BY(*mutex); - se::gpu::OwnedCudaGraphExec exec ABSL_GUARDED_BY(*mutex); + se::gpu::OwnedGpuGraphExec exec ABSL_GUARDED_BY(*mutex); // Access to a graph instance must be synchronized, because we potentially can // run concurrent graph instance updates. std::unique_ptr mutex; }; -#else // #if !GOOGLE_CUDA +#else // #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM -// Define empty struct and empty state when CUDA is not enabled. 
+// Define empty struct and empty state when GPU is not enabled. struct GraphInstance {}; class StreamExecutorGraphInstances : public runtime::StateVector {}; -#endif // #if GOOGLE_CUDA +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Xla executable keeps a mapping from stream executors to graph instances. // diff --git a/tensorflow/compiler/xla/service/gpu/runtime/kernel_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/kernel_launch.cc index 3fa3da0a336b93..26102bf378d7d8 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/kernel_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/kernel_launch.cc @@ -31,9 +31,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/stream_executor/kernel.h" -#if GOOGLE_CUDA -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" -#endif // #if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h" +#endif // #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace xla { namespace gpu { @@ -79,7 +79,7 @@ static absl::Status LaunchImpl( })); assert((*kernel)->name() == name && "unexpected loaded kernel"); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN(bool is_capturing, se::gpu::IsStreamCapturing(stream)); #else bool is_capturing = false; @@ -88,10 +88,10 @@ static absl::Status LaunchImpl( if (is_capturing) { if (region_status->IsInConcurrentRegion()) { VLOG(3) << "Launching " << (*kernel)->name() - << "in a concurrent region during CUDA graph capture"; + << "in a concurrent region during GPU graph capture"; } else { VLOG(3) << "Launching " << (*kernel)->name() - << "during CUDA graph capture"; + << "during GPU graph capture"; } } else { VLOG(3) << "Launching " << (*kernel)->name(); @@ -117,7 +117,7 @@ static absl::Status LaunchImpl( // Always add temporary buffer as the last kernel argument. buffer_args.back() = *temp_buffer; - // If we are capturing a concurrent region in a CUDA graph, then use the + // If we are capturing a concurrent region in a GPU graph, then use the // stream provided by ConcurrentRegionStatus to execute the kernel. se::Stream* execution_stream = stream; if (region_status->IsInConcurrentRegion()) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/BUILD b/tensorflow/compiler/xla/stream_executor/cuda/BUILD index f3ad4282a83848..b9b2d5da004aae 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/BUILD +++ b/tensorflow/compiler/xla/stream_executor/cuda/BUILD @@ -489,17 +489,6 @@ cc_library( ]), ) -cc_library( - name = "cuda_graph", - srcs = if_cuda_is_configured(["cuda_graph.cc"]), - hdrs = if_cuda_is_configured(["cuda_graph.h"]), - deps = if_cuda_is_configured([ - "@com_google_absl//absl/strings:str_format", - "@local_config_cuda//cuda:cuda_headers", - "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream", - ]), -) - cc_library( name = "cuda_stream", srcs = [], diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc deleted file mode 100644 index 5363a570bef5b4..00000000000000 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h" - -#include - -#include "absl/strings/str_format.h" -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" -#include "tensorflow/tsl/platform/env.h" -#include "tensorflow/tsl/platform/path.h" - -namespace stream_executor { -namespace gpu { - -//===----------------------------------------------------------------------===// -// RAII helpers for CUDA graph types. -//===----------------------------------------------------------------------===// - -std::atomic CudaGraphSupport::allocated_cuda_graph_execs_; -std::atomic CudaGraphSupport::alive_cuda_graph_execs_; - -/*static*/ size_t CudaGraphSupport::NotifyGraphExecCreated() { - alive_cuda_graph_execs_.fetch_add(1, std::memory_order_relaxed); - return allocated_cuda_graph_execs_.fetch_add(1, std::memory_order_relaxed); -} - -/*static*/ size_t CudaGraphSupport::NotifyGraphExecDestroyed() { - return alive_cuda_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; -} - -/*static*/ size_t CudaGraphSupport::allocated_cuda_graph_execs() { - return allocated_cuda_graph_execs_.load(std::memory_order_relaxed); -} - -/*static*/ size_t CudaGraphSupport::alive_cuda_graph_execs() { - return alive_cuda_graph_execs_.load(std::memory_order_relaxed); -} - -void CudaGraphSupport::DestroyGraph::operator()(cudaGraph_t graph) { - cudaError_t err = cudaGraphDestroy(graph); - CHECK(err == cudaSuccess) - << "Failed to destroy CUDA graph: " << cudaGetErrorString(err); -} - -void CudaGraphSupport::DestroyGraphExec::operator()(cudaGraphExec_t instance) { - cudaError_t err = cudaGraphExecDestroy(instance); - CHECK(err == cudaSuccess) - << "Failed to destroy CUDA graph instance: " << cudaGetErrorString(err); -} - -tsl::Status OwnedCudaGraphExec::Update(OwnedCudaGraph graph) { - VLOG(3) << "Update CUDA graph exec with a new graph after " << num_launches_ - << " launches since last update" - << " #" << num_updates_++; - - num_launches_ = 0; - -#if CUDA_VERSION >= 12000 - cudaGraphExecUpdateResultInfo updated; - - auto err = cudaGraphExecUpdate(get(), graph.get(), &updated); - if (err != cudaSuccess || updated.result != cudaGraphExecUpdateSuccess) - return absl::InternalError(absl::StrFormat( - "failed to update cuda graph: %s", cudaGetErrorString(err))); - -#else - cudaGraphExecUpdateResult updated; - cudaGraphNode_t error_node; - - auto err = cudaGraphExecUpdate(get(), graph.get(), &error_node, &updated); - if (err != cudaSuccess || updated != cudaGraphExecUpdateSuccess) - return absl::InternalError(absl::StrFormat("Failed to update cuda graph %s", - cudaGetErrorString(err))); -#endif - - return tsl::OkStatus(); -} - -tsl::Status OwnedCudaGraphExec::Launch(stream_executor::Stream* stream) { - VLOG(3) << "Launch CUDA graph " << get() - << " on a stream: " << stream->DebugStreamPointers() << " #" - << ++num_launches_; - - if (auto err = cudaGraphLaunch(get(), AsGpuStreamValue(stream)); - err != cudaSuccess) - return 
absl::InternalError(absl::StrFormat("failed to run cuda graph: %s", - cudaGetErrorString(err))); - - return tsl::OkStatus(); -} - -OwnedCudaGraphExec::~OwnedCudaGraphExec() { - if (*this) // do not log for moved-from instances - VLOG(5) << "Destroy CUDA graph exec #" << id_ - << " (remaining alive instances: " - << CudaGraphSupport::NotifyGraphExecDestroyed() << ")"; -} - -//===----------------------------------------------------------------------===// -// CUDA Graph Helpers. -//===----------------------------------------------------------------------===// - -tsl::StatusOr CaptureCudaGraph( - stream_executor::Stream* stream, absl::AnyInvocable capture, - cudaStreamCaptureMode mode) { - VLOG(3) << "Capture CUDA graph on a stream: " - << stream->DebugStreamPointers(); - - cudaGraph_t graph; - - // Get the underlying CUDA stream for passing to CUDA APIs. - auto gpu_stream = AsGpuStreamValue(stream); - - // Capture graph constructed by the exported graph capture function. - if (auto err = cudaStreamBeginCapture(gpu_stream, mode); err != cudaSuccess) - return absl::InternalError(absl::StrFormat( - "stream begin capture failed: %s", cudaGetErrorString(err))); - - // Call into graph capture function. - auto captured = capture(); - - // Always stop capturing the stream before checking `captured` result. - if (auto err = cudaStreamEndCapture(gpu_stream, &graph); err != cudaSuccess) - return absl::InternalError(absl::StrFormat("stream end capture failed: %s", - cudaGetErrorString(err))); - - if (!captured.ok()) - return absl::InternalError(absl::StrFormat( - "failed to capture CUDA graph: %s", captured.message())); - - VLOG(5) << "Captured CUDA graph " << graph; - -#if CUDA_VERSION >= 12000 - // If verbose logging is enabled print captured CUDA graph debug information. - if (VLOG_IS_ON(100)) { - if (const char* path = getenv("XLA_CUDA_GRAPH_DEBUG_DIRECTORY"); path) { - std::string file = tsl::io::JoinPath(std::string(path), "/cuda_graph-"); - - if (tsl::Env::Default()->CreateUniqueFileName(&file, ".dot")) { - VLOG(100) << "Print CUDA graph " << graph - << " debug dot file to: " << file; - - int flags = cudaGraphDebugDotFlagsVerbose; - if (auto err = cudaGraphDebugDotPrint(graph, file.c_str(), flags); - err != cudaSuccess) { - LOG(WARNING) << "failed to print CUDA graph debug file: " - << cudaGetErrorString(err); - - } else if (VLOG_IS_ON(200)) { - std::string data; - if (tsl::ReadFileToString(tsl::Env::Default(), file, &data).ok()) { - VLOG(200) << "CUDA graph " << graph << " debug file:\n" << data; - } else { - LOG(WARNING) << "failed to read CUDA graph debug file"; - } - } - - } else { - LOG(WARNING) << "cannot create unique filename, won't enable CUDA " - "graph debugging"; - } - } - } -#endif // CUDA_VERSION >= 12000 - - return OwnedCudaGraph(graph); -} - -tsl::StatusOr InstantiateCudaGraph(OwnedCudaGraph graph) { - cudaGraphExec_t exec; - -#if CUDA_VERSION >= 12000 - if (auto err = cudaGraphInstantiate(&exec, &*graph); -#else - if (auto err = cudaGraphInstantiate(&exec, &*graph, nullptr, nullptr, 0); -#endif - err != cudaSuccess) { - if (err == cudaErrorMemoryAllocation) { - // OOM is a recoverable error, we evict all instantiated cuda graphs to - // free up some space (see graph launch.cc). Clear error status. 
- return absl::ResourceExhaustedError( - absl::StrFormat("graph instantiation failed: %s", - cudaGetErrorString(cudaGetLastError()))); - } else { - return absl::InternalError(absl::StrFormat( - "graph instantiation failed: %s", cudaGetErrorString(err))); - } - } - - size_t id = CudaGraphSupport::NotifyGraphExecCreated(); - VLOG(5) << "Instantiated CUDA graph exec instance #" << id - << " (alive instances: " << CudaGraphSupport::alive_cuda_graph_execs() - << ")"; - return OwnedCudaGraphExec(id, exec); -} - -tsl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { - cudaStreamCaptureStatus capture_status; - cudaError_t err = cudaStreamIsCapturing( - stream_executor::gpu::AsGpuStreamValue(stream), &capture_status); - if (err != cudaSuccess) { - return absl::InternalError(absl::StrFormat( - "Failed to get stream's capture status: %s", cudaGetErrorString(err))); - } - - return capture_status == cudaStreamCaptureStatusActive; -} - -} // namespace gpu -} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index eefd3bd0a5e00f..6e1e2bd9c8ec7a 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -424,3 +424,22 @@ tsl_gpu_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "gpu_graph", + srcs = if_gpu_is_configured(["gpu_graph.cc"]), + hdrs = if_gpu_is_configured(["gpu_graph.h"]), + deps = if_gpu_is_configured([ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/functional:any_invocable", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_stream", + "//tensorflow/compiler/xla/stream_executor", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", + ]) + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "//tensorflow/compiler/xla/stream_executor/rocm:rocm_driver", + ]), +) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.cc new file mode 100644 index 00000000000000..795e8474ea8e36 --- /dev/null +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.cc @@ -0,0 +1,259 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/path.h" + +#if TENSORFLOW_USE_ROCM +using namespace stream_executor::wrap; // NOLINT[build/namespaces] +#define GPU_PREFIX hip +#else +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#define GPU_PREFIX cuda +#endif + +#define GPU_CAT_NX(A, B) A##B +#define GPU_CAT(A, B) GPU_CAT_NX(A, B) +#define GPU(A) GPU_CAT(GPU_PREFIX, A) + +#define GpuGetErrorString GPU(GetErrorString) +#define GpuGraphDebugDotFlagsVerbose GPU(GraphDebugDotFlagsVerbose) +#define GpuGraphDebugDotPrint GPU(GraphDebugDotPrint) +#define GpuGraphDestroy GPU(GraphDestroy) +#define GpuErrorMemoryAllocation GPU(ErrorMemoryAllocation) +#define GpuGraphExecDestroy GPU(GraphExecDestroy) +#define GpuGraphExecUpdate GPU(GraphExecUpdate) +#define GpuGraphExecUpdateResult GPU(GraphExecUpdateResult) +#define GpuGraphExecUpdateSuccess GPU(GraphExecUpdateSuccess) +#define GpuGraphInstantiate GPU(GraphInstantiate) +#define GpuGraphLaunch GPU(GraphLaunch) +#define GpuGraphNode GPU(GraphNode_t) +#define GpuStreamBeginCapture GPU(StreamBeginCapture) +#define GpuStreamCaptureModeThreadLocal GPU(StreamCaptureModeThreadLocal) +#define GpuStreamCaptureStatus GPU(StreamCaptureStatus) +#define GpuStreamCaptureStatusActive GPU(StreamCaptureStatusActive) +#define GpuStreamEndCapture GPU(StreamEndCapture) +#define GpuStreamIsCapturing GPU(StreamIsCapturing) +#define GpuSuccess GPU(Success) + +#define RETURN_IF_GPU_GRAPH_ERROR(expr, ...) \ + do { \ + auto _res = (expr); \ + if (TF_PREDICT_FALSE(_res != GpuSuccess)) { \ + return tsl::errors::Internal(__VA_ARGS__, ": ", \ + GpuGetErrorString(_res)); \ + } \ + } while (0) + +namespace stream_executor { +namespace gpu { + +//===----------------------------------------------------------------------===// +// RAII helpers for gpu graph types. 
+//===----------------------------------------------------------------------===// + +std::atomic GpuGraphSupport::allocated_gpu_graph_execs_; +std::atomic GpuGraphSupport::alive_gpu_graph_execs_; + +/*static*/ size_t GpuGraphSupport::NotifyGraphExecCreated() { + alive_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); + return allocated_gpu_graph_execs_.fetch_add(1, std::memory_order_relaxed); +} + +/*static*/ size_t GpuGraphSupport::NotifyGraphExecDestroyed() { + return alive_gpu_graph_execs_.fetch_sub(1, std::memory_order_relaxed) - 1; +} + +/*static*/ size_t GpuGraphSupport::allocated_gpu_graph_execs() { + return allocated_gpu_graph_execs_.load(std::memory_order_relaxed); +} + +/*static*/ size_t GpuGraphSupport::alive_gpu_graph_execs() { + return alive_gpu_graph_execs_.load(std::memory_order_relaxed); +} + +void GpuGraphSupport::DestroyGraph::operator()(GpuGraphHandle graph) { + auto err = GpuGraphDestroy(graph); + CHECK(err == GpuSuccess) << "Failed to destroy gpu graph: " + << GpuGetErrorString(err); +} + +void GpuGraphSupport::DestroyGraphExec::operator()( + GpuGraphExecHandle instance) { + auto err = GpuGraphExecDestroy(instance); + CHECK(err == GpuSuccess) << "Failed to destroy gpu graph instance: " + << GpuGetErrorString(err); +} + +tsl::Status OwnedGpuGraphExec::Update(OwnedGpuGraph graph) { + VLOG(3) << "Update gpu graph exec with a new graph after " << num_launches_ + << " launches since last update" + << " #" << num_updates_++; + + num_launches_ = 0; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + cudaGraphExecUpdateResultInfo updated; + + auto err = cudaGraphExecUpdate(get(), graph.get(), &updated); + if (err != cudaSuccess || updated.result != cudaGraphExecUpdateSuccess) + return tsl::errors::Internal("Failed to update gpu graph: ", + GpuGetErrorString(err)); + +#else + GpuGraphExecUpdateResult updated; + GpuGraphNode error_node; + + auto err = GpuGraphExecUpdate(get(), graph.get(), &error_node, &updated); + if (err != GpuSuccess || updated != GpuGraphExecUpdateSuccess) + return tsl::errors::Internal("Failed to update gpu graph: ", + GpuGetErrorString(err)); +#endif + + return tsl::OkStatus(); +} + +tsl::Status OwnedGpuGraphExec::Launch(stream_executor::Stream* stream) { + VLOG(3) << "Launch gpu graph " << get() + << " on a stream: " << stream->DebugStreamPointers() << " #" + << ++num_launches_; + + RETURN_IF_GPU_GRAPH_ERROR(GpuGraphLaunch(get(), AsGpuStreamValue(stream)), + "failed to run gpu graph"); + + return tsl::OkStatus(); +} + +OwnedGpuGraphExec::~OwnedGpuGraphExec() { + if (*this) // do not log for moved-from instances + VLOG(5) << "Destroy GPU graph exec #" << id_ + << " (remaining alive instances: " + << GpuGraphSupport::NotifyGraphExecDestroyed() << ")"; +} + +//===----------------------------------------------------------------------===// +// GPU Graph Helpers. +//===----------------------------------------------------------------------===// + +tsl::StatusOr CaptureGpuGraph( + stream_executor::Stream* stream, + absl::AnyInvocable capture) { + VLOG(3) << "Capture gpu graph on a stream: " << stream->DebugStreamPointers(); + + GpuGraphHandle graph; + + // Get the underlying stream for passing to GPU runtime APIs. + auto gpu_stream = AsGpuStreamValue(stream); + + // Capture graph constructed by the exported graph capture function. + RETURN_IF_GPU_GRAPH_ERROR( + GpuStreamBeginCapture(gpu_stream, GpuStreamCaptureModeThreadLocal), + "stream begin capture failed"); + + // Call into graph capture function. 
+ auto captured = capture(); + + // Always stop capturing the stream before checking `captured` result. + RETURN_IF_GPU_GRAPH_ERROR(GpuStreamEndCapture(gpu_stream, &graph), + "stream end capture failed"); + + if (!captured.ok()) + return tsl::errors::Internal("failed to capture gpu graph: ", + captured.message()); + + VLOG(5) << "Captured gpu graph " << graph; + +#if TENSORFLOW_USE_ROCM || CUDA_VERSION >= 12000 + // If verbose logging is enabled print captured gpu graph debug information. + if (VLOG_IS_ON(100)) { + if (const char* path = getenv("XLA_GPU_GRAPH_DEBUG_DIRECTORY"); path) { + std::string file = tsl::io::JoinPath(std::string(path), "/gpu_graph-"); + + if (tsl::Env::Default()->CreateUniqueFileName(&file, ".dot")) { + VLOG(100) << "Print gpu graph " << graph + << " debug dot file to: " << file; + + int flags = GpuGraphDebugDotFlagsVerbose; + if (auto err = GpuGraphDebugDotPrint(graph, file.c_str(), flags); + err != GpuSuccess) { + LOG(WARNING) << "failed to print gpu graph debug file: " + << GpuGetErrorString(err); + + } else if (VLOG_IS_ON(200)) { + std::string data; + if (tsl::ReadFileToString(tsl::Env::Default(), file, &data).ok()) { + VLOG(200) << "gpu graph " << graph << " debug file:\n" << data; + } else { + LOG(WARNING) << "failed to read gpu graph debug file"; + } + } + + } else { + LOG(WARNING) << "cannot create unique filename, won't enable gpu " + "graph debugging"; + } + } + } +#endif // TENSORFLOW_USE_ROCM || CUDA_VERSION >= 12000 + + return OwnedGpuGraph(graph); +} + +tsl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph) { + GpuGraphExecHandle exec; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + if (auto err = cudaGraphInstantiate(&exec, &*graph); +#else + if (auto err = GpuGraphInstantiate(&exec, &*graph, nullptr, nullptr, 0); +#endif + err != GpuSuccess) { + if (err == GpuErrorMemoryAllocation) { + // OOM is a recoverable error, we evict all instantiated cuda graphs to + // free up some space (see graph launch.cc). Clear error status. + return absl::ResourceExhaustedError(absl::StrFormat( + "graph instantiation failed: %s", GpuGetErrorString(err))); + } else { + return absl::InternalError(absl::StrFormat( + "graph instantiation failed: %s", GpuGetErrorString(err))); + } + } + + size_t id = GpuGraphSupport::NotifyGraphExecCreated(); + VLOG(5) << "Instantiated gpu graph exec instance #" << id + << " (alive instances: " << GpuGraphSupport::alive_gpu_graph_execs() + << ")"; + return OwnedGpuGraphExec(id, exec); +} + +tsl::StatusOr IsStreamCapturing(stream_executor::Stream* stream) { + GpuStreamCaptureStatus capture_status; + RETURN_IF_GPU_GRAPH_ERROR( + GpuStreamIsCapturing(stream_executor::gpu::AsGpuStreamValue(stream), + &capture_status), + "Failed to get stream's capture status"); + + return capture_status == GpuStreamCaptureStatusActive; +} + +} // namespace gpu +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h similarity index 53% rename from tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h rename to tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h index ad56554c0ad300..69cd6632a9f4af 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_graph.h @@ -13,73 +13,87 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_GRAPH_H_ -#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ #include #include #include #include "absl/functional/any_invocable.h" -#include "third_party/gpus/cuda/include/driver_types.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/tsl/platform/statusor.h" +#if TENSORFLOW_USE_ROCM +#include "tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h" +#else +#include "third_party/gpus/cuda/include/driver_types.h" +#endif + +#if TENSORFLOW_USE_ROCM +using GpuGraphHandle = hipGraph_t; +using GpuGraphExecHandle = hipGraphExec_t; +#else +using GpuGraphHandle = cudaGraph_t; +using GpuGraphExecHandle = cudaGraphExec_t; +#endif + namespace stream_executor { namespace gpu { -class CudaGraphSupport { +class GpuGraphSupport { public: - // Deleters for CUDA graph and graph exec instance that check the returned - // status and terminate if it's not `cudaSuccess`. + // Deleters for gpu graph and graph exec instance that check the returned + // status and terminate on error. struct DestroyGraph { - void operator()(cudaGraph_t); + void operator()(GpuGraphHandle); }; struct DestroyGraphExec { - void operator()(cudaGraphExec_t); + void operator()(GpuGraphExecHandle); }; static size_t NotifyGraphExecCreated(); static size_t NotifyGraphExecDestroyed(); - static size_t allocated_cuda_graph_execs(); - static size_t alive_cuda_graph_execs(); + static size_t allocated_gpu_graph_execs(); + static size_t alive_gpu_graph_execs(); private: - // Global counters for the total number of allocated and alive CUDA graph + // Global counters for the total number of allocated and alive gpu graph // execs to track the resource usage at run time. - static std::atomic allocated_cuda_graph_execs_; - static std::atomic alive_cuda_graph_execs_; + static std::atomic allocated_gpu_graph_execs_; + static std::atomic alive_gpu_graph_execs_; }; //===----------------------------------------------------------------------===// -// RAII helpers for CUDA graph types. +// RAII helpers for gpu graph types. //===----------------------------------------------------------------------===// -class OwnedCudaGraph - : public std::unique_ptr, - CudaGraphSupport::DestroyGraph> { +class OwnedGpuGraph + : public std::unique_ptr, + GpuGraphSupport::DestroyGraph> { // Bring std::unique_ptr constructors in scope. 
- using std::unique_ptr, - CudaGraphSupport::DestroyGraph>::unique_ptr; + using std::unique_ptr, + GpuGraphSupport::DestroyGraph>::unique_ptr; }; -class OwnedCudaGraphExec - : public std::unique_ptr, - CudaGraphSupport::DestroyGraphExec> { - using Base = std::unique_ptr, - CudaGraphSupport::DestroyGraphExec>; +class OwnedGpuGraphExec + : public std::unique_ptr, + GpuGraphSupport::DestroyGraphExec> { + using Base = std::unique_ptr, + GpuGraphSupport::DestroyGraphExec>; public: - OwnedCudaGraphExec(uint64_t id, cudaGraphExec_t exec) : Base(exec), id_(id) {} - ~OwnedCudaGraphExec(); + OwnedGpuGraphExec(uint64_t id, GpuGraphExecHandle exec) + : Base(exec), id_(id) {} + ~OwnedGpuGraphExec(); - OwnedCudaGraphExec(OwnedCudaGraphExec&&) = default; - OwnedCudaGraphExec& operator=(OwnedCudaGraphExec&&) = default; + OwnedGpuGraphExec(OwnedGpuGraphExec&&) = default; + OwnedGpuGraphExec& operator=(OwnedGpuGraphExec&&) = default; // Updates executable graph instance with a newly captured graph. Returns an // error if the new graph is not compatible (see `cudaGraphExecUpdate`). - tsl::Status Update(OwnedCudaGraph graph); + tsl::Status Update(OwnedGpuGraph graph); // Launches captured graph on a given stream. tsl::Status Launch(stream_executor::Stream* stream); @@ -93,17 +107,16 @@ class OwnedCudaGraphExec }; //===----------------------------------------------------------------------===// -// CUDA Graph Helpers. +// Gpu Graph Helpers. //===----------------------------------------------------------------------===// // Captures all operations added to a `stream` by the `capture` function into -// the cuda graph instance. -tsl::StatusOr CaptureCudaGraph( - stream_executor::Stream* stream, absl::AnyInvocable capture, - cudaStreamCaptureMode mode = cudaStreamCaptureModeThreadLocal); +// the gpu graph instance. +tsl::StatusOr CaptureGpuGraph( + stream_executor::Stream* stream, absl::AnyInvocable capture); -// Instantiates a captured cuda graph instance into a cuda graph executable. -tsl::StatusOr InstantiateCudaGraph(OwnedCudaGraph graph); +// Instantiates a captured gpu graph instance into a gpu graph executable. 
+tsl::StatusOr InstantiateGpuGraph(OwnedGpuGraph graph); // Returns true if the stream is in graph capture mode tsl::StatusOr IsStreamCapturing(stream_executor ::Stream* stream); @@ -111,4 +124,4 @@ tsl::StatusOr IsStreamCapturing(stream_executor ::Stream* stream); } // namespace gpu } // namespace stream_executor -#endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_GRAPH_H_ diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h index 3ef762f8cd2717..0089ffd367f5a7 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -91,6 +91,13 @@ namespace wrap { __macro(hipGetDevice) \ __macro(hipGetDeviceCount) \ __macro(hipGetDeviceProperties) \ + __macro(hipGetErrorString) \ + __macro(hipGraphDebugDotPrint) \ + __macro(hipGraphDestroy) \ + __macro(hipGraphExecDestroy) \ + __macro(hipGraphExecUpdate) \ + __macro(hipGraphInstantiate) \ + __macro(hipGraphLaunch) \ __macro(hipHostFree) \ __macro(hipHostMalloc) \ __macro(hipHostRegister) \ @@ -123,9 +130,12 @@ namespace wrap { __macro(hipSetDevice) \ __macro(hipDeviceGetStreamPriorityRange) \ __macro(hipStreamAddCallback) \ + __macro(hipStreamBeginCapture) \ __macro(hipStreamCreateWithFlags) \ __macro(hipStreamCreateWithPriority) \ __macro(hipStreamDestroy) \ + __macro(hipStreamEndCapture) \ + __macro(hipStreamIsCapturing) \ __macro(hipStreamQuery) \ __macro(hipStreamSynchronize) \ __macro(hipStreamWaitEvent) // clang-format on diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 5ce1f65330c66c..9431fe3a93d3af 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -442,31 +442,31 @@ message DebugOptions { // Whether to use cuBLASLt for GEMMs on GPUs. bool xla_gpu_enable_cublaslt = 166; - // 0: Disable CUDA graph capture. - // 1: Enable cuda graphs for fusions and memcpy (safest ones). - // 2: Enable cuda graphs for gemms and convs. - // 3+ Enable cuda graphs for collectives. + // 0: Disable GPU graph capture. + // 1: Enable GPU graphs for fusions and memcpy (safest ones). + // 2: Enable GPU graphs for gemms and convs. + // 3+ Enable GPU graphs for collectives. // // Default: 0. - int32 xla_gpu_cuda_graph_level = 194; + int32 xla_gpu_graph_level = 194; - // Only instantiates a CUDA graph after the captured function execution count + // Only instantiates a GPU graph after the captured function execution count // reaches the threshold. This constant is a heuristic to avoid creating a // large number of CUDA graph instances in memory. - int32 xla_gpu_cuda_graph_num_runs_to_instantiate = 202; + int32 xla_gpu_graph_num_runs_to_instantiate = 202; // This number determines how many moved instructions like fusion kernels are - // required for a region to be captured as a function to be launched as a cuda + // required for a region to be captured as a function to be launched as a GPU // graph. - int32 xla_gpu_cuda_graph_min_graph_size = 208; + int32 xla_gpu_graph_min_graph_size = 208; - // Identify concurrent regions in cuda graphs and execute them concurrently. - bool xla_gpu_cuda_graph_enable_concurrent_region = 215; + // Identify concurrent regions in GPU graphs and execute them concurrently. + bool xla_gpu_graph_enable_concurrent_region = 215; // Timeout in seconds to evict instantiated Gpu graphs from device. 
When XLA // instantiates new Gpu graphs, it evicts graphs that were not recently // executed to free space on device. - int32 xla_gpu_cuda_graph_eviction_timeout_seconds = 230; + int32 xla_gpu_graph_eviction_timeout_seconds = 230; // Allocate temp buffers once during the first execution of an executable. // Reuse the allocated buffers in subsequent executions. Executables cannot From 042f9ebb67bcbeb244e001d7f543f8c7515662fc Mon Sep 17 00:00:00 2001 From: Pat Notz Date: Mon, 24 Jul 2023 15:04:14 -0700 Subject: [PATCH 073/410] Use embedding sequencing when summary ops are detected PiperOrigin-RevId: 550684087 --- .../tests/embedding_pipelining.mlir | 39 ++++++++++++++++ .../tests/embedding_sequencing.mlir | 44 +++++++++++++++++++ .../mlir/tensorflow/transforms/bridge.cc | 13 +++--- .../transforms/embedding_pipelining.cc | 38 +++++++++++++--- .../transforms/embedding_sequencing.cc | 4 ++ 5 files changed, 127 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir index 408342b0ebd202..9e75324b1ebdf2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_pipelining.mlir @@ -481,3 +481,42 @@ module { // CHECK: {{.*StatefulPartitionedCall\"\(%arg0\).*f = @func1.*}} // CHECK: {{.*StatefulPartitionedCall\"\(%arg0\).*f = @func2.*}} } + +// ----- +// This test verifies that the pipelining pass has no effect when tf.WriteSummaryOp is present. +module { + func.func @main(%arg0: tensor<*x!tf_type.resource>) { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0:2 = "tf.While"(%cst_main, %arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf_type.resource>) -> (tensor, tensor<*x!tf_type.resource>) + return + } + func.func private @while_body(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>) -> (tensor, tensor<*x!tf_type.resource>) { + // CHECK-NOT: {{.*tf.While.* body = @new_while_body.* cond = @new_while_cond.*}} + // CHECK: return + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + %tensor_int64 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %tensor_string = "tf.Const"() {value = dense<""> : tensor} : () -> tensor + "tf.WriteSummary"(%arg1, %tensor_int64, %tensor_int64, %tensor_string, %tensor_string): (tensor<*x!tf_type.resource>, tensor, tensor, tensor, tensor) -> () + + return %res_n, %arg1 : tensor, tensor<*x!tf_type.resource> + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir index 0a8a30698618a1..d1eb8507e0d43e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/embedding_sequencing.mlir @@ -317,3 +317,47 @@ module { return %0 : tensor } } + +// ----- +// This test verifies that the sequencing pass WAI even when tf.WriteSummaryOp is present. This is a complement to the logic/test for embedding_pipelining.cc/mlir +module { + func.func @main(%arg0: tensor<*x!tf_type.resource>) { + %cst_main = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0:2 = "tf.While"(%cst_main, %arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf_type.resource>) -> (tensor, tensor<*x!tf_type.resource>) + return + } + func.func private @while_body(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>) -> (tensor, tensor<*x!tf_type.resource>) { + // Verify that everything is extracted into one of the four functions. + // The order of these functions is also significant. + // CHECK: {{.*StatefulPartitionedCall.* f = @_func_non_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_forward.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_core_tpu.*}} + // CHECK-NEXT: {{.*StatefulPartitionedCall.* f = @_func_sc_backward.*}} + // CHECK-NEXT: return + // metadata ops + "tf.TPUReplicateMetadata"() {_has_manual_control_dependencies = true, _replication_info = "repl_info", num_replicas = 2 : i64} : () -> () + %comp_res = "tf.TPUCompilationResult"() {_tpu_compilation_status = "repl_info"} : () -> tensor + + // forward_ops + %res_f = "tf.Const"() {_embedding_pipelining = "forward", _replication_info = "repl_info", value = dense<2> : tensor} : () -> tensor + + // core_tpu ops: + %res_t = "tf.Identity"(%res_f) {_replication_info = "repl_info"} : (tensor) -> tensor + + // backward_ops + %res_b = "tf.Identity"(%res_t) {_embedding_pipelining = "backward", _replication_info = "repl_info"} : (tensor) -> tensor + + // non_tpu_ops + %res_n = "tf.Identity"(%arg0) : (tensor) -> tensor + + %tensor_int64 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %tensor_string = "tf.Const"() {value = dense<""> : tensor} : () -> tensor + "tf.WriteSummary"(%arg1, %tensor_int64, %tensor_int64, %tensor_string, %tensor_string): (tensor<*x!tf_type.resource>, tensor, tensor, tensor, tensor) -> () + + return %res_n, %arg1 : tensor, tensor<*x!tf_type.resource> + } + func.func private @while_cond(%arg0: tensor, %arg1: tensor<*x!tf_type.resource>) -> tensor { + %0 = "tf.Less"(%arg0, %arg0) : (tensor, tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 92ff89da0a4caa..2babd26307b3b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -149,12 +149,13 @@ void CreateTPUBridgePipelineImpl( pm.addNestedPass( CreateTPUReorderReplicateAndPartitionedInputsPass()); pm.addNestedPass(TF::CreateDecomposeReduceDatasetPass()); - if (tensorflow::GetBuildXlaOpsPassFlags() - ->tf_xla_disable_full_embedding_pipelining) { - pm.addPass(TFDevice::CreateEmbeddingSequencingPass()); - } else { - pm.addPass(TFDevice::CreateEmbeddingPipeliningPass()); - } + // Only one of EmbeddingSequencing and EmbeddingPipelining will actually + // run and the logic is in EmbeddingPipeliningPass. 
If the pipelining pass + // runs, embedding attributes are stripped and the sequencing pass will have + // no effect. If the pipelining pass doesn't run, embedding attributes are + // preserved and the sequencing rewrite will trigger. + pm.addPass(TFDevice::CreateEmbeddingPipeliningPass()); + pm.addPass(TFDevice::CreateEmbeddingSequencingPass()); pm.addPass(CreateTPUClusterFormationPass()); // CreateEmbeddingPipeliningPass may have created more functions, but // TPUClusterCleanup and OutsideCompiledToHostLaunch need every function to be diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc index 84b161a0fd7862..ffd4aed3bb2ed7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc @@ -124,13 +124,11 @@ return selected_results // #include "smartass/brain/ops/flogs_ops.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -152,6 +150,7 @@ return selected_results #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -192,6 +191,30 @@ struct EmbeddingPipeliningPass void runOnOperation() override; }; +bool UseEmbeddingPipelining(ModuleOp& module) { + // Enable automated pipelining pass unless: + // 1. The user disables it via flog, or + // 2. The graph contains TF.Summary ops. Graphs like this typically only run + // for a single step which doesn't work in pipelining. + + if (tensorflow::GetBuildXlaOpsPassFlags() + ->tf_xla_disable_full_embedding_pipelining) + return false; + + // Detect summaries by looking for key Ops in the graph. It would be better to + // do this via operator attributes rather than looking for a specific op. 
+ WalkResult walk_result = module.walk([&](Operation* op) -> WalkResult { + if (llvm::isa(op)) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walk_result.wasInterrupted()) { + VLOG(1) << "TF summaries detected - disabling embedding pipelining."; + return false; + } + VLOG(1) << "Embedding pipelining rewrite enabled."; + return true; +} + StringAttr GetReplicationAttr(mlir::Operation* op) { return op->getAttrOfType(TF::kReplicationInfoAttr); } @@ -1218,7 +1241,7 @@ LogicalResult StartStep0(OpBuilder& builder, Location& loc, const std::string name = "start_step_0"; AddAssertion(builder, loc, cond_value, - "Auto-pipelining requires at least two steps."); + "[StartStep0] Auto-pipelining requires at least two steps."); auto insertion_point = builder.saveInsertionPoint(); func::FuncOp orig_parent_func = @@ -1302,7 +1325,7 @@ LogicalResult StartStep1(OpBuilder& builder, Location& loc, const std::string name = "start_step_1"; AddAssertion(builder, loc, cond_value, - "Auto-pipelining requires at least two steps."); + "[StartStep1] Auto-pipelining requires at least two steps."); auto insertion_point = builder.saveInsertionPoint(); func::FuncOp orig_parent_func = @@ -1362,7 +1385,7 @@ LogicalResult FinishStepNm2(OpBuilder& builder, Location& loc, const std::string name = "finish_step_nm2"; AddAssertion(builder, loc, cond_value, - "Auto-pipelining requires at least two steps."); + "[FinishStepNm2] Auto-pipelining requires at least two steps."); auto insertion_point = builder.saveInsertionPoint(); func::FuncOp orig_parent_func = @@ -1622,6 +1645,11 @@ Operation* LiftNonTpuFuncCaller(mlir::OpBuilder& builder, void EmbeddingPipeliningPass::runOnOperation() { VLOG(3) << "EmbeddingPipeliningPass::runOnOperation()"; ModuleOp module = getOperation(); + + // We only use one of the EmbeddingPipelining and EmbeddingSequencing passes. + if (!UseEmbeddingPipelining(module)) return; + VLOG(1) << "Embedding pipelining rewrite enabled."; + SymbolTable symbol_table(module); llvm::SetVector forward_pass_ops; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc index a83f6ac54a8af4..6fdac01fe7a001 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc @@ -765,6 +765,7 @@ LogicalResult ExtractOpsAsFunc( } void EmbeddingSequencingPass::runOnOperation() { + VLOG(3) << "EmbeddingSequencingPass::runOnOperation()"; ModuleOp module = getOperation(); llvm::SetVector forward_pass_ops; @@ -802,6 +803,7 @@ void EmbeddingSequencingPass::runOnOperation() { return signalPassFailure(); } } + VLOG(1) << "Embedding sequencing rewrite enabled."; // Ensure that all ops are in the same region, and have the same replication // info. 
@@ -912,6 +914,8 @@ void EmbeddingSequencingPass::runOnOperation() { metadata_op->erase(); compilation_op->erase(); + + VLOG(3) << "EmbeddingSequencingPass::runOnOperation done."; } } // namespace From 3995176d57fa0e8c8555521b3d694dd48cd2c538 Mon Sep 17 00:00:00 2001 From: David Silverstone Date: Mon, 24 Jul 2023 15:05:23 -0700 Subject: [PATCH 074/410] Replace deprecated, opaque `TPUHostTransferProto` with serialized representation PiperOrigin-RevId: 550684392 --- .../tpu/kernels/tpu_executable_info.proto | 32 ++----------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/tensorflow/core/tpu/kernels/tpu_executable_info.proto b/tensorflow/core/tpu/kernels/tpu_executable_info.proto index 669d3e31be679d..cadcd8439b1189 100644 --- a/tensorflow/core/tpu/kernels/tpu_executable_info.proto +++ b/tensorflow/core/tpu/kernels/tpu_executable_info.proto @@ -63,34 +63,8 @@ message TPUExecutableInfoProto { xla.DeviceAssignmentProto device_assignment = 6; } -// Metadata for a data transfer between device and host. -message TPUHostTransferProto { - enum TransferDirection { - NONE = 0; - DEVICE_TO_HOST = 1; - HOST_TO_DEVICE = 2; - } - // Channel identifier assigned by compiler and used in host commands. - int64 channel = 1; - // Direction of the transfer operation. - TransferDirection direction = 2; - // Channel identifier prodided by XLA client. - string key = 3; - reserved 4; // was nested_while_level - // Shape of the data to be transferred (including layout). - xla.ShapeProto shape = 5; - // Address of the device buffer in HBM (byte offset). - int64 buffer_offset = 6; - // Original data type for this host transfer before X64 rewrite. - xla.PrimitiveType original_type = 7; - // If this host transfer is a splitted X64 transfer, specifies whether this - // transfer is for lower bits. - bool is_lower_bits = 8; - // The name of host side command handler. - string host_handler_name = 9; - reserved 10; -} - message TPUHostTransferInfoProto { - repeated TPUHostTransferProto host_transfers = 1; + reserved 1; + // Serialized metadata for a data transfer between device and host. + repeated bytes serialized_transfers = 2 [ctype = CORD]; } From 6586360a3e898de43f3238ffa7120b702cc55248 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 15:16:48 -0700 Subject: [PATCH 075/410] tf.signal: Fix nan in kaiser_window on xla_gpu. Round-off error was creating a slightly negative value in the argument to math_ops.sqrt(). 
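For illustration only, a minimal NumPy sketch of the clamp-before-sqrt idea behind this fix (the function name and values here are hypothetical, not the TF ops touched by the change):

    import numpy as np

    def stable_sqrt_term(arg, halflen):
      # Mathematically 1 - (arg / halflen)**2 lies in [0, 1], but round-off
      # can push it slightly below zero, turning the sqrt into NaN.
      # Clamping at zero (the relu in the actual change) removes that
      # failure mode without affecting exact results.
      val = 1.0 - np.square(arg / halflen)
      return np.sqrt(np.maximum(val, 0.0))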
TESTED: - window_ops_test_xla_gpu fails before this change, and passes after PiperOrigin-RevId: 550687551 --- tensorflow/python/ops/signal/BUILD | 1 + tensorflow/python/ops/signal/window_ops.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/signal/BUILD b/tensorflow/python/ops/signal/BUILD index a87dd8bf525b32..94fe785c21b908 100644 --- a/tensorflow/python/ops/signal/BUILD +++ b/tensorflow/python/ops/signal/BUILD @@ -164,6 +164,7 @@ py_strict_library( "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:cond", "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", "//tensorflow/python/ops:special_math_ops", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:tf_export", diff --git a/tensorflow/python/ops/signal/window_ops.py b/tensorflow/python/ops/signal/window_ops.py index 0dafd1e0a45827..e0c62d0ebef43b 100644 --- a/tensorflow/python/ops/signal/window_ops.py +++ b/tensorflow/python/ops/signal/window_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -80,10 +81,9 @@ def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None): arg = math_ops.cast(arg, dtype=dtype) beta = math_ops.cast(beta, dtype=dtype) one = math_ops.cast(1.0, dtype=dtype) - two = math_ops.cast(2.0, dtype=dtype) halflen_float = math_ops.cast(halflen_float, dtype=dtype) - num = beta * math_ops.sqrt( - one - math_ops.pow(arg, two) / math_ops.pow(halflen_float, two)) + num = beta * math_ops.sqrt(nn_ops.relu( + one - math_ops.square(arg / halflen_float))) window = math_ops.exp(num - beta) * ( special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta)) return window From 6d1d4f301d8885a578299562ec41c492106422e9 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Mon, 24 Jul 2023 15:17:44 -0700 Subject: [PATCH 076/410] #tf-data Ramp up `"file_locality_v2"` experiment to 10%. PiperOrigin-RevId: 550687785 --- tensorflow/core/data/dataset_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 5a4f58b6f06256..733ccb42554444 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -978,7 +978,7 @@ REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<50>, AllTasks); REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<10>, IndependentHostTasks); -REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<0>, +REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<10>, IndependentHostTasks); } // namespace } // namespace data From 99bfee183e9768a61e503123e5358ac518a2fb65 Mon Sep 17 00:00:00 2001 From: Tianrun Li Date: Mon, 24 Jul 2023 16:01:24 -0700 Subject: [PATCH 077/410] Add support for parted layout type. (1) Parted layout is generated by where op expander. (2) Propagate this layout to all the descendant nodes. (3) During SPMD expansion, if DTensor finds a parted layout, it will skip the expansion and only do local compute. 
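As an illustrative sketch only (plain Python, not the DTensor API; the layout-type strings are hypothetical), the propagation rule in (2) amounts to:

    # If any operand layout is parted, every layout computed for the op's
    # outputs is switched to parted as well; otherwise layouts are unchanged.
    def adjust_parted(input_types, output_types):
      if any(t == "parted" for t in input_types):
        return ["parted"] * len(output_types)
      return list(output_types)

    assert adjust_parted(["sharded", "parted"], ["sharded"]) == ["parted"]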
PiperOrigin-RevId: 550699427 --- tensorflow/dtensor/cc/tensor_layout.cc | 6 +- tensorflow/dtensor/cc/tensor_layout.h | 5 +- tensorflow/dtensor/mlir/collectives.cc | 4 +- .../mlir/expansions/where_spmd_expander.cc | 76 +++---------------- .../dtensor/mlir/layout_propagation_v2.cc | 5 +- tensorflow/dtensor/mlir/spmd_expander.cc | 73 +++++++++++++++++- tensorflow/dtensor/python/tests/BUILD | 2 + .../python/tests/layout_propagation_test.py | 34 ++++----- .../dtensor/python/tests/layout_test.py | 3 +- tensorflow/dtensor/python/tests/spmd_test.py | 14 +++- 10 files changed, 125 insertions(+), 97 deletions(-) diff --git a/tensorflow/dtensor/cc/tensor_layout.cc b/tensorflow/dtensor/cc/tensor_layout.cc index 0b7d6c76595f6d..26a70937006f53 100644 --- a/tensorflow/dtensor/cc/tensor_layout.cc +++ b/tensorflow/dtensor/cc/tensor_layout.cc @@ -1105,9 +1105,13 @@ StatusOr Layout::ToProto() const { } bool Layout::IsEquivalent(const Layout& b) const { + if (this->type() != b.type()) return false; + return IsEquivalentIgnoringType(b); +} + +bool Layout::IsEquivalentIgnoringType(const Layout& b) const { if (this->rank() != b.rank()) return false; if (this->mesh() != b.mesh()) return false; - if (this->type() != b.type()) return false; for (int i = 0; i < this->rank(); ++i) { if (this->sharding_specs_[i] != b.sharding_specs_[i]) { if ((this->num_shards_for_dim(i) != 1) || (b.num_shards_for_dim(i) != 1)) diff --git a/tensorflow/dtensor/cc/tensor_layout.h b/tensorflow/dtensor/cc/tensor_layout.h index db1aa24551e6f7..444769ccd81649 100644 --- a/tensorflow/dtensor/cc/tensor_layout.h +++ b/tensorflow/dtensor/cc/tensor_layout.h @@ -430,10 +430,13 @@ class Layout { const std::string& sharding_spec(int idx) const; + // Similar to IsEquivalentIgnoringType, but also verifies the layout type are + // equal. + bool IsEquivalent(const Layout& b) const; // Two layouts are equivalent if they would result in the same sharding for // the tensor. E.g. if one is unsharded and the other is sharded on a mesh // dimension of size 1. - bool IsEquivalent(const Layout& b) const; + bool IsEquivalentIgnoringType(const Layout& b) const; // Uses proto to compare the equality. If any conversion to proto fails, // returns false. bool operator==(const Layout& b) const; diff --git a/tensorflow/dtensor/mlir/collectives.cc b/tensorflow/dtensor/mlir/collectives.cc index 0dfb9fd820b89d..696fd52923b36e 100644 --- a/tensorflow/dtensor/mlir/collectives.cc +++ b/tensorflow/dtensor/mlir/collectives.cc @@ -299,7 +299,9 @@ StatusOr EmitRelayout( mlir::OpBuilder builder(input.getContext()); TF_RETURN_IF_ERROR(SetBuilderInsertionAfterValue(input, builder)); - if (src_layout.IsEquivalent(tgt_layout)) { + // If two layouts are the same, or the only difference is layout type, then + // there is no need to actually relayout data. + if (src_layout.IsEquivalentIgnoringType(tgt_layout)) { mlir::TF::IdentityOp op = builder.create( input.getLoc(), input.getType(), input); if (newly_created_ops != nullptr) newly_created_ops->insert(op); diff --git a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc index c93ecf9e4ca679..c0641aa0756186 100644 --- a/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/where_spmd_expander.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "llvm/ADT/DenseMap.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -36,73 +37,10 @@ limitations under the License. namespace tensorflow { namespace dtensor { -namespace { - -// Convert the local index (associated with local tensor per device) to the -// global index (associated with the global tensor). The local index must be the -// first result of `op`. -StatusOr LocalIndexToGlobalIndex(mlir::Operation* op) { - TF_ASSIGN_OR_RETURN(auto input_layout, - ExtractLayoutFromOperand(op->getOperand(0))); - - mlir::OpBuilder builder(op); - builder.setInsertionPointAfter(op); - - // Calculate the index offset using DeviceId, for now, DTensor only supports - // index conversion when sharding is on the first dimension. - mlir::Value num_devices_per_dim_0 = - IntConst(builder, op->getLoc(), - input_layout->mesh().num_local_devices() / - input_layout->num_shards_for_dim(0)); - TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(op)); - mlir::Value device_id_offset = builder.create( - op->getLoc(), device_id, num_devices_per_dim_0); - - TF_ASSIGN_OR_RETURN(const auto& shape, - ExtractGlobalInputShape(op->getOpOperand(0))); - mlir::Value size_per_shard_dim_0 = IntConst( - builder, op->getLoc(), shape[0] / input_layout->num_shards_for_dim(0)); - mlir::Value index_offset = builder.create( - op->getLoc(), size_per_shard_dim_0, device_id_offset); - - // Add index offset to the local index to get the global index. - mlir::Value index_offset_i64 = builder.create( - op->getLoc(), - mlir::RankedTensorType::get( - index_offset.getType().cast().getShape(), - builder.getIntegerType(64)), - index_offset); - mlir::Value global_index = builder.create( - op->getLoc(), op->getResultTypes(), index_offset_i64, op->getOpResult(0)); - - op->getOpResult(0).replaceAllUsesExcept(global_index, - global_index.getDefiningOp()); - - return global_index.getDefiningOp(); -} - -} // namespace - StatusOr WhereOpSPMDExpander::ExpandOp(mlir::Operation* op) { - TF_ASSIGN_OR_RETURN(auto input_layout, - ExtractLayoutFromOperand(op->getOperand(0))); - assert(input_layout); - - // If input is fully replicated, there is no need to manipulate the index - // calculated by the Where Op, just return directly. - if (input_layout->IsFullyReplicated()) { - return op; - } - - // Only supports sharding on the first dimension. - if (!input_layout->IsBatchParallel()) { - return absl::InvalidArgumentError( - "Where op only supports batch sharding for now."); - } - - // Where Op returns the indices of the non-zero elements in the input tensor. - // Convert the local index to global index as the final output. - return LocalIndexToGlobalIndex(op); + // Where op does not do anything during SPMD expansion. + // Because the output layout is parted layout and it follows local semantic. + return op; } StatusOr> WhereOpSPMDExpander::ComputeLayoutForward( @@ -121,8 +59,12 @@ StatusOr> WhereOpSPMDExpander::ComputeLayoutForward( std::vector layout_specs; layout_specs.push_back(layout.sharding_spec(0)); layout_specs.push_back(Layout::kUnshardedDim); + // The output of Where Op contains dynamic shape and has parted layout. + // This is the source of the parted layout and it is propagated to descendent + // ops. 
TF_ASSIGN_OR_RETURN(Layout new_layout, - Layout::GetLayout(layout_specs, layout.mesh())); + Layout::GetLayout(Layout::LayoutType::kParted, + layout_specs, layout.mesh())); return llvm::DenseMap({{0, new_layout}}); } diff --git a/tensorflow/dtensor/mlir/layout_propagation_v2.cc b/tensorflow/dtensor/mlir/layout_propagation_v2.cc index ecc601bcafd129..6cc87e7e3c564c 100644 --- a/tensorflow/dtensor/mlir/layout_propagation_v2.cc +++ b/tensorflow/dtensor/mlir/layout_propagation_v2.cc @@ -299,7 +299,10 @@ StatusOr MergeLayouts( } } FilterkAnySpecs(proposed_specs); - return Layout::GetLayout(proposed_specs, mesh); + // Parted layout is propagated from producer side to consumer side, so when + // merging the layout with producer layout, take the producer layout type into + // account. + return Layout::GetLayout(producer->type(), proposed_specs, mesh); } mlir::LogicalResult InsertLayoutsForDTensorLayout( diff --git a/tensorflow/dtensor/mlir/spmd_expander.cc b/tensorflow/dtensor/mlir/spmd_expander.cc index 01719759074a99..ce6b34c7a004b5 100644 --- a/tensorflow/dtensor/mlir/spmd_expander.cc +++ b/tensorflow/dtensor/mlir/spmd_expander.cc @@ -25,7 +25,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" +#include "llvm/ADT/DenseMap.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -46,6 +48,58 @@ limitations under the License. namespace tensorflow { namespace dtensor { +namespace { + +// Adjust the layout to parted if the input has parted layout. +// This is only used by the forward layout propagation, not the backward. The +// parted layout can only be generated by the Where Op and then affect the +// descendent nodes. +// User should not explicitly set a output parted layout and expect it to affect +// the layout of ancestor nodes. +Status AdjustPartedLayout(const llvm::DenseMap& input_layouts, + llvm::DenseMap* computed_layouts) { + // If any input has parted layout, propagate the parted layout to the layout + // of all the computed values. + bool input_has_parted_layout = false; + for (const auto& input_layout : input_layouts) { + if (input_layout.second.type() == Layout::LayoutType::kParted) { + input_has_parted_layout = true; + break; + } + } + if (input_has_parted_layout) { + for (auto& computed_layout : *computed_layouts) { + TF_ASSIGN_OR_RETURN(Layout parted, computed_layout.second.ToParted()); + computed_layout.getSecond() = parted; + } + } + return OkStatus(); +} + +// Returns whether DTensor should skip SPMD expansion because `op` uses parted +// layout. +bool SkipExpansionForPartedLayout(mlir::Operation* op) { + // If op is a terminating return op, don't skip the SPMD expansion. + if (llvm::isa(op)) { + return false; + } + // Check if any input operand uses parted layout. 
+ auto status_or_input_layouts = ExtractRequiredLayoutFromOperands(op); + if (!status_or_input_layouts.ok()) { + return false; + } + bool operand_uses_parted_layout = false; + for (const auto& layout : status_or_input_layouts.value()) { + if (layout.type() == Layout::LayoutType::kParted) { + operand_uses_parted_layout = true; + break; + } + } + return operand_uses_parted_layout; +} + +} // namespace + // static SPMDExpanderRegistry* SPMDExpanderRegistry::Global() { static SPMDExpanderRegistry* registry = new SPMDExpanderRegistry(); @@ -99,10 +153,19 @@ Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, .c_str()); } - // If op is on an XLA SPMD mesh, then set layout and skip expansion. + // If op is on an XLA SPMD mesh, then set layout and skip expansion. There is + // no need to infer local shape because XLA SPMD expects global shape. + // If op skips SPMD expansion because of parted layout, infer the local shape + // and return. TF_ASSIGN_OR_RETURN(const Mesh& mesh, ExtractDeviceMeshEnclosingCluster(op)); - if (mesh.IsSingleDevice() || mesh.use_xla_spmd()) { - *output = op; + bool skip_expansion_for_parted_layout = SkipExpansionForPartedLayout(op); + if (mesh.IsSingleDevice() || mesh.use_xla_spmd() || + skip_expansion_for_parted_layout) { + if (skip_expansion_for_parted_layout) { + *output = InferSPMDExpandedLocalShape(op); + } else { + *output = op; + } SetLayoutOnOp(*output, absl::Span>( computed_layout.data(), computed_layout.size())); return OkStatus(); @@ -199,7 +262,9 @@ StatusOr> SPMDExpanderBase::ComputeLayoutForward( } return layouts; } - return ComputeLayoutForward(op, input_layouts); + TF_ASSIGN_OR_RETURN(auto layouts, ComputeLayoutForward(op, input_layouts)); + TF_RETURN_IF_ERROR(AdjustPartedLayout(input_layouts, &layouts)); + return layouts; } StatusOr> SPMDExpanderBase::ComputeLayoutBackward( diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index d6cfb986c3043f..5b64ad1d4b549b 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -277,6 +277,8 @@ dtensor_test( "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", "//tensorflow/python/ops:stateless_random_ops", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", diff --git a/tensorflow/dtensor/python/tests/layout_propagation_test.py b/tensorflow/dtensor/python/tests/layout_propagation_test.py index 7cd3f990c9a450..7b3efc49cc6fae 100644 --- a/tensorflow/dtensor/python/tests/layout_propagation_test.py +++ b/tensorflow/dtensor/python/tests/layout_propagation_test.py @@ -26,7 +26,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -34,25 +36,8 @@ UNSHARDED = layout.UNSHARDED # Convenient constants to use for tests. 
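Before the test changes below, a brief illustration of the two helpers above: AdjustPartedLayout only runs in the forward direction and simply re-types every computed layout when any operand is parted, while SkipExpansionForPartedLayout keeps such ops un-expanded because their operands already hold local shards. With made-up sharding specs (the layout type names other than kParted are not shown in this patch):

  input_layouts    = {0: ['x', 'unsharded'] (parted)}
  computed_layouts = {0: ['x', 'unsharded'] (default type)}   // from ComputeLayoutForward
  after adjustment = {0: ['x', 'unsharded'] (parted)}          // ToParted() applied to every entry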
-_MESH_DIM_BATCH = 'batch' _MESH_DIM_X = 'x' _MESH_DIM_Y = 'y' -_MESH_2D_STRING = ( - '|batch=2,x=2|0,1,2,3|0,1,2,3|' - '/job:localhost/replica:0/task:0/device:TPU:0,' - '/job:localhost/replica:0/task:0/device:TPU:1,' - '/job:localhost/replica:0/task:0/device:TPU:2,' - '/job:localhost/replica:0/task:0/device:TPU:3' -) - -_2D_GLOBAL_IDS = test_util.create_device_ids_array((1, 2)) - -_2D_MESH = layout.Mesh([_MESH_DIM_BATCH, _MESH_DIM_X], _2D_GLOBAL_IDS, - np.ravel(_2D_GLOBAL_IDS).tolist(), - test_util.create_device_list((1, 2), 'TPU')) -_2D_X_Y_MESH = layout.Mesh([_MESH_DIM_X, _MESH_DIM_Y], _2D_GLOBAL_IDS, - np.ravel(_2D_GLOBAL_IDS).tolist(), - test_util.create_device_list((1, 2), 'CPU')) class LayoutPropagationV2Test(test_util.DTensorBaseTest): @@ -386,6 +371,21 @@ def func(input_a, input_b): ) api.check_layout(result, unsharded_x) + def test_layout_prop_parted_layout(self): + value = np.array([1.0, 2.0, 3.0, 4.0]) + expected = nn_ops.softmax_v2(gen_nn_ops.relu(value)) + + @polymorphic_function.function + def func(value): + return nn_ops.softmax_v2(gen_nn_ops.relu(value)) + + parted_layout = self.x_layout.to_parted() + result = func(api.relayout(value, parted_layout)) + + # Verifies the parted layout can be propagated through a chain of ops to the + # final output. + self.assertDTensorEqual(expected, parted_layout, result) + if __name__ == '__main__': test.main() diff --git a/tensorflow/dtensor/python/tests/layout_test.py b/tensorflow/dtensor/python/tests/layout_test.py index 0f40507c8684bb..b8e367a4901dd5 100644 --- a/tensorflow/dtensor/python/tests/layout_test.py +++ b/tensorflow/dtensor/python/tests/layout_test.py @@ -472,9 +472,10 @@ def do_relayout(): @combinations.generate(combinations.combine(is_graph=[False, True])) def test_relayout_to_parted(self, is_graph): data = np.array([1, 2, 3, 4.0], dtype='f4') + inp = api.relayout(data, self.y_layout) def do_relayout(): - return api.relayout(data, self.y_layout.to_parted()) + return api.relayout(inp, self.y_layout.to_parted()) if is_graph: do_relayout = polymorphic_function.function(do_relayout) diff --git a/tensorflow/dtensor/python/tests/spmd_test.py b/tensorflow/dtensor/python/tests/spmd_test.py index ece2253cc28b6e..c9c58abe47a674 100644 --- a/tensorflow/dtensor/python/tests/spmd_test.py +++ b/tensorflow/dtensor/python/tests/spmd_test.py @@ -2162,7 +2162,9 @@ def boolean_mask_func(t, m): return array_ops.boolean_mask(t, m) result = boolean_mask_func(tensor, mask) - self.assertDTensorEqual(expected, expected_output_layout, result) + self.assertDTensorEqual( + expected, expected_output_layout.to_parted(), result + ) def testRawWhere(self): if self.mesh.use_xla_spmd(): @@ -2171,8 +2173,6 @@ def testRawWhere(self): condition = constant_op.constant( np.array([True, True, False, False, True, False, True, True]) ) - expected = gen_array_ops.where(condition) - condition = api.relayout(condition, self.first_dimension_sharded_layout_1d) @polymorphic_function.function @@ -2180,8 +2180,14 @@ def func(c): return gen_array_ops.where(c) result = func(condition) + # With parted layout, the raw where op will output local index instead of + # global index. So the second half of test expectation ([0], [2], [3]) has + # an offset of 4. 
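To make the offset arithmetic in the new comment concrete (this assumes the 1-D mesh used here has 2 shards, so each device holds 4 of the 8 elements): device 0 sees [True, True, False, False] and emits local indices [[0], [1]]; device 1 sees [True, False, True, True] and emits [[0], [2], [3]]. Concatenated they form the [[0], [1], [0], [2], [3]] expectation below, whereas the global indices of the unsharded condition would be [[0], [1], [4], [6], [7]].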
+ expected = constant_op.constant( + np.array([[0], [1], [0], [2], [3]]), dtype=dtypes.int64 + ) self.assertDTensorEqual( - expected, self.first_dimension_sharded_layout, result + expected, self.first_dimension_sharded_layout.to_parted(), result ) @parameterized.named_parameters([ From ea2fa79af6f8d527b06eb09862fc9edb43590f90 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 24 Jul 2023 16:28:09 -0700 Subject: [PATCH 078/410] Fixes for TfLiteRegistrationExternalsCache. 1. Bug fix: change the cache of TfLiteRegistrationExternals objects in OpResolver from a set of pointers to a hash map (with operator name / code and version as the key). Without this, there was no reuse of such objects, potentially leading to an unbounded memory leak if the same OpResolver was used repeatedly with different Interpreter objects. 2. Change the cache of TfLiteRegistrationExternals objects in OpResolver from a unique_ptr to a shared_ptr, to allow the same map to be shared between the OpResolver and the InterpreterBuilder/Interpreter/Subgraph. PiperOrigin-RevId: 550706140 --- tensorflow/lite/c/BUILD | 20 ++++- tensorflow/lite/c/c_api_opaque_internal.cc | 46 ++++++++-- tensorflow/lite/c/c_api_opaque_internal.h | 56 +++++++++--- .../lite/c/c_api_opaque_internal_test.cc | 76 ++++++++++++++++ tensorflow/lite/core/api/op_resolver.h | 89 ++++++++++++++++++- tensorflow/lite/core/c/c_api.cc | 12 +-- tensorflow/lite/core/subgraph.cc | 32 ++++++- tensorflow/lite/core/subgraph.h | 13 +-- tensorflow/lite/delegates/delegate_test.cc | 4 + .../delegates/utils/simple_opaque_delegate.cc | 5 +- tensorflow/lite/special_rules.bzl | 6 ++ .../lite/tools/serialization/writer_lib.h | 19 +++- 12 files changed, 339 insertions(+), 39 deletions(-) create mode 100644 tensorflow/lite/c/c_api_opaque_internal_test.cc diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index c713456fd344fc..bc00f79e36c19c 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -9,7 +9,7 @@ load( "tflite_self_contained_libs_test_suite", ) load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") -load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:special_rules.bzl", "c_api_opaque_internal_visibility_allowlist", "tflite_portable_test_suite") load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite_with_c_headers_test") package( @@ -292,7 +292,7 @@ cc_library( tags = ["allow_undefined_symbols"], # For tflite::CreateOpResolver(). visibility = [ "//tensorflow/lite:__subpackages__", - ], + ] + c_api_opaque_internal_visibility_allowlist(), deps = [ ":c_api_types", "//tensorflow/lite:framework", @@ -303,6 +303,22 @@ cc_library( ], ) +cc_test( + name = "c_api_opaque_internal_test", + srcs = ["c_api_opaque_internal_test.cc"], + data = [ + "//tensorflow/lite:testdata/add.bin", + ], + deps = [ + ":c_api_internal", + ":c_api_opaque_internal", + "//tensorflow/lite:builtin_ops", + "//tensorflow/lite:framework", + "//tensorflow/lite/kernels:builtin_ops", + "@com_google_googletest//:gtest_main", + ], +) + # Same as c_api_opaque_internal, but depends on the # '_without_alwayslink' variant of ':c_api_without_op_resolver'. cc_library( diff --git a/tensorflow/lite/c/c_api_opaque_internal.cc b/tensorflow/lite/c/c_api_opaque_internal.cc index e0bd99308eb113..2f7f1113cb5bb2 100644 --- a/tensorflow/lite/c/c_api_opaque_internal.cc +++ b/tensorflow/lite/c/c_api_opaque_internal.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/c/registration_external.h" @@ -24,9 +25,12 @@ limitations under the License. namespace tflite { namespace internal { -TfLiteRegistrationExternal* -CommonOpaqueConversionUtil::ObtainRegistrationExternal( - TfLiteContext* context, TfLiteRegistration* registration, int node_index) { +namespace { + +// Returns a dynamically allocated object; the caller is responsible for +// deallocating it using TfLiteRegistrationExternalDelete. +TfLiteRegistrationExternal* MakeRegistrationExternal( + const TfLiteRegistration* registration, int node_index) { // We need to allocate a new TfLiteRegistrationExternal object and then // populate its state correctly, based on the contents in 'registration'. @@ -36,13 +40,41 @@ CommonOpaqueConversionUtil::ObtainRegistrationExternal( registration_external->node_index = node_index; - registration->registration_external = registration_external; + return registration_external; +} + +} // anonymous namespace + +TfLiteRegistrationExternal* +CommonOpaqueConversionUtil::CachedObtainRegistrationExternal( + RegistrationExternalsCache* registration_externals_cache, + const TfLiteRegistration* registration, int node_index) { + OpResolver::OpId op_id{registration->builtin_code, registration->custom_name, + registration->version}; + auto it = registration_externals_cache->find(op_id); + if (it != registration_externals_cache->end()) { + return it->second.get(); + } + auto* registration_external = + MakeRegistrationExternal(registration, node_index); + registration_externals_cache->insert( + it, std::make_pair(op_id, registration_external)); - auto* subgraph = static_cast(context->impl_); - subgraph->registration_externals_.insert( - std::unique_ptr(registration_external)); return registration_external; } +TfLiteRegistrationExternal* +CommonOpaqueConversionUtil::ObtainRegistrationExternal( + TfLiteContext* context, const TfLiteRegistration* registration, + int node_index) { + auto* subgraph = static_cast(context->impl_); + if (!subgraph->registration_externals_) { + subgraph->registration_externals_ = + std::make_shared(); + } + return CachedObtainRegistrationExternal( + subgraph->registration_externals_.get(), registration, node_index); +} + } // namespace internal } // namespace tflite diff --git a/tensorflow/lite/c/c_api_opaque_internal.h b/tensorflow/lite/c/c_api_opaque_internal.h index e9b31285572ebd..f274b2ba833be7 100644 --- a/tensorflow/lite/c/c_api_opaque_internal.h +++ b/tensorflow/lite/c/c_api_opaque_internal.h @@ -15,33 +15,61 @@ limitations under the License. #ifndef TENSORFLOW_LITE_C_C_API_OPAQUE_INTERNAL_H_ #define TENSORFLOW_LITE_C_C_API_OPAQUE_INTERNAL_H_ +#include + +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/c/common.h" +// Internal structures and subroutines used by the C API. These are likely to +// change and should not be depended on directly by any C API clients. +// +// NOTE: This header does not follow C conventions and does not define a C API. +// It is effectively an (internal) implementation detail of the C API. + namespace tflite { namespace internal { class CommonOpaqueConversionUtil { public: - // Create a 'TfLiteRegistrationExternal' object that corresponds to the - // provided 'registration' argument, set it as the 'registration's - // 'registration_external' field and return the address of the external - // registration. 
We loosely define that a 'TfLiteRegistrationExternal' object - // "corresponds" to a 'TfLiteRegistration' object when calling any function - // pointer (like 'prepare') on the 'TfLiteRegistrationExternal' object calls - // into the corresponding function pointer of the 'TfLiteRegistration' object. + // Obtain (or create) a 'TfLiteRegistrationExternal' object that corresponds + // to the provided 'registration' argument, and return the address of the + // external registration. We loosely define that a + // 'TfLiteRegistrationExternal' object "corresponds" to a 'TfLiteRegistration' + // object when calling any function pointer (like 'prepare') on the + // 'TfLiteRegistrationExternal' object calls into the corresponding function + // pointer of the 'TfLiteRegistration' object. // - // The specified 'context' is used to store the 'TfLiteRegistrationExternal*' - // pointers. The 'TfLiteRegistrationExternal*' pointer will be deallocated - // when the 'context' gets destroyed. I.e., the caller of this function - // should not deallocate the object pointed to by the return value of - // 'ObtainRegistrationExternal'. + // The specified 'context' or 'op_resolver' object is used to store the + // 'TfLiteRegistrationExternal*' pointers. The 'TfLiteRegistrationExternal*' + // pointer will be deallocated when that object gets destroyed. I.e., the + // caller of this function should not deallocate the object pointed to by the + // return value of 'ObtainRegistrationExternal'. // // We also need to provide the 'node_index' that the 'registration' // corresponds to, so that the 'TfLiteRegistrationExternal' can store that - // index within its fields. + // index within its fields. If the registration does not yet correspond + // to a specific node index, then 'node_index' should be -1. static TfLiteRegistrationExternal* ObtainRegistrationExternal( - TfLiteContext* context, TfLiteRegistration* registration, int node_index); + TfLiteContext* context, const TfLiteRegistration* registration, + int node_index); + + // Get a shared_ptr to the RegistrationExternalsCache from an OpResolver. + // This is used to allow the InterpreterBuilder and OpResolver to share + // the same RegistrationExternalsCache, so that the RegistrationExternal + // objects in it can persist for the lifetimes of both the InterpreterBuilder + // and OpResolver. + static std::shared_ptr<::tflite::internal::RegistrationExternalsCache> + GetSharedCache(const ::tflite::OpResolver& op_resolver) { + return op_resolver.registration_externals_cache_; + } + + private: + static TfLiteRegistrationExternal* CachedObtainRegistrationExternal( + ::tflite::internal::RegistrationExternalsCache* + registration_externals_cache, + const TfLiteRegistration* registration, int node_index); }; + } // namespace internal } // namespace tflite #endif // TENSORFLOW_LITE_C_C_API_OPAQUE_INTERNAL_H_ diff --git a/tensorflow/lite/c/c_api_opaque_internal_test.cc b/tensorflow/lite/c/c_api_opaque_internal_test.cc new file mode 100644 index 00000000000000..0965b85f173c60 --- /dev/null +++ b/tensorflow/lite/c/c_api_opaque_internal_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/c/c_api_opaque_internal.h" + +#include + +#include +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" + +using tflite::FlatBufferModel; +using tflite::Interpreter; +using tflite::InterpreterBuilder; +using tflite::internal::CommonOpaqueConversionUtil; +using tflite::ops::builtin::BuiltinOpResolver; + +TEST(ObtainRegistrationFromContext, ProducesValidResult) { + BuiltinOpResolver op_resolver; + std::unique_ptr interpreter; + std::unique_ptr model = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + InterpreterBuilder builder(*model, op_resolver); + ASSERT_EQ(builder(&interpreter), kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + TfLiteContext* context = interpreter->primary_subgraph().context(); + const TfLiteRegistration* registration = tflite::ops::builtin::Register_ADD(); + + TfLiteRegistrationExternal* registration_external = + CommonOpaqueConversionUtil::ObtainRegistrationExternal(context, + registration, 42); + + ASSERT_EQ(registration_external->builtin_code, kTfLiteBuiltinAdd); + ASSERT_EQ(registration_external->version, registration->version); + ASSERT_EQ(registration_external->custom_name, registration->custom_name); + ASSERT_EQ(registration_external->node_index, 42); +} + +TEST(ObtainRegistrationFromContext, CachingWorks) { + BuiltinOpResolver op_resolver; + std::unique_ptr interpreter; + std::unique_ptr model = FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + InterpreterBuilder builder(*model, op_resolver); + ASSERT_EQ(builder(&interpreter), kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + TfLiteContext* context = interpreter->primary_subgraph().context(); + const TfLiteRegistration* registration = tflite::ops::builtin::Register_ADD(); + + // Call it twice, and verify that we get the same result back. + TfLiteRegistrationExternal* registration_external1 = + CommonOpaqueConversionUtil::ObtainRegistrationExternal(context, + registration, 0); + TfLiteRegistrationExternal* registration_external2 = + CommonOpaqueConversionUtil::ObtainRegistrationExternal(context, + registration, 1); + ASSERT_EQ(registration_external1, registration_external2); +} diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index e8a4e32771a456..b43f1adc664ada 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -16,7 +16,10 @@ limitations under the License. #define TENSORFLOW_LITE_CORE_API_OP_RESOLVER_H_ #include +#include #include +#include +#include #include #include "tensorflow/lite/core/api/error_reporter.h" @@ -25,6 +28,16 @@ limitations under the License. 
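The CachingWorks test above exercises reuse within a single interpreter; the unbounded growth described in the commit message shows up when one resolver outlives many interpreters. A hypothetical usage sketch of that pattern (the loop count and the reuse of the add.bin test model are illustrative only):

#include <memory>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"

// One long-lived resolver reused across many interpreters: the scenario from
// point 1 of the commit message. With the old pointer-set cache each build
// could add new TfLiteRegistrationExternal objects that lived as long as the
// resolver; the keyed hash map caps this at one object per (op, version).
void BuildManyInterpreters() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  auto model = tflite::FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/add.bin");
  for (int i = 0; i < 100; ++i) {
    std::unique_ptr<tflite::Interpreter> interpreter;
    tflite::InterpreterBuilder builder(*model, resolver);
    builder(&interpreter);
  }
}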
namespace tflite { +#ifndef DOXYGEN_SKIP +class OpResolverInternal; // For friend declaration below. +class Subgraph; // For friend declaration below. + +namespace internal { +class CommonOpaqueConversionUtil; // For friend declaration below. +class RegistrationExternalsCache; // Forward decl. +} // namespace internal +#endif + /// Abstract interface that returns TfLiteRegistrations given op codes or custom /// op names. This is the mechanism that ops being referenced in the flatbuffer /// model are mapped to executable function pointers (TfLiteRegistrations). @@ -104,7 +117,9 @@ class OpResolver { return {}; } - virtual ~OpResolver() {} + virtual ~OpResolver() = default; + OpResolver() = default; + OpResolver(const OpResolver& other) = default; private: /// Returns true if this OpResolver may contain any "user defined" ops. @@ -120,8 +135,80 @@ class OpResolver { /// "builtin" ops, and may not support all of the "builtin" op enum values. virtual bool MayContainUserDefinedOps() const { return true; } +#ifndef DOXYGEN_SKIP friend class OpResolverInternal; + friend class Subgraph; // For OpId. + friend class tflite::internal::CommonOpaqueConversionUtil; + friend class tflite::internal::RegistrationExternalsCache; +#endif + + // This holds the identity of an operator. + // Ths is used as the key for the RegistrationExternalsCache below. + struct OpId { + int builtin_code; + const char* custom_name; + int version; + bool operator==(const OpId& other) const { + return builtin_code == other.builtin_code && + custom_name == other.custom_name && version == other.version; + } + struct Hasher { + size_t operator()(const OpId& op_id) const { + size_t hash_builtin_code = std::hash()(op_id.builtin_code); + size_t hash_custom_name = + op_id.custom_name != nullptr + ? std::hash()(std::string(op_id.custom_name)) + : 0; + size_t hash_version = std::hash()(op_id.version); + return Combine(hash_builtin_code, + Combine(hash_custom_name, hash_version)); + } + + private: + static size_t Combine(size_t hash1, size_t hash2) { + constexpr int num_bits_to_rotate_left = 21; + constexpr int num_bits_to_rotate_right = + std::numeric_limits::digits - num_bits_to_rotate_left; + size_t hash1_rotated = (hash1 << num_bits_to_rotate_left) | + (hash1 >> num_bits_to_rotate_right); + return hash1_rotated + hash2; + } + }; + }; + + // A set of 'TfLiteRegistrationExternal' objects whose lifetimes need to + // last at least as long as the lifetime of the OpResolver. + // We use shared_ptr rather than unique_ptr here, to allow the + // RegistrationExternalsCache to be shared with other classes such as the + // InterpreterBuilder and Interpreter. This is so that the + // TfLiteRegistrationExternal objects allocated by an OpResolver, + // which may be referenced by a Subgraph in an Interpreter, can remain live + // even if the OpResolver is destroyed, while also allowing the same + // OpResolver to be used with multiple InterpreterBuilders and multiple + // Interpreters. + mutable std::shared_ptr + registration_externals_cache_; +}; + +#ifndef DOXYGEN_SKIP +// Type for a set of owned 'TfLiteRegistrationExternal' objects. +// This is needed when converting TfLiteRegistration to +// TfLiteRegistrationExternal, to ensure that the number of +// TfLiteRegistrationExternal objects that we allocate is bounded, and to +// ensure that those objects get deallocated at the appropriate time. +// We use a public class rather than a typedef or using declaration here, +// to ensure that the class can be forward-declared. 
+// WARNING: Experimental interface, subject to change. +namespace internal { +class RegistrationExternalsCache + : private std::unordered_map, + OpResolver::OpId::Hasher> { + friend class ::tflite::Subgraph; + friend class ::tflite::internal::CommonOpaqueConversionUtil; }; +} // namespace internal +#endif // Handles the logic for converting between an OperatorCode structure extracted // from a flatbuffer and information about a registered operator diff --git a/tensorflow/lite/core/c/c_api.cc b/tensorflow/lite/core/c/c_api.cc index fd07b58478b4fd..343a2866645894 100644 --- a/tensorflow/lite/core/c/c_api.cc +++ b/tensorflow/lite/core/c/c_api.cc @@ -309,15 +309,17 @@ const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op, // Try using newer RegistrationExternal API. if (op_resolver_callbacks_.find_builtin_op_external) { - // Get a RegistrationExternal object and create a Registration (V3) object. + // Get a RegistrationExternal object and create a Registration (V4) object. const TfLiteRegistrationExternal* registration_external = op_resolver_callbacks_.find_builtin_op_external( op_resolver_callbacks_.user_data, static_cast(op), version); - if (registration_external && (registration_external->init != nullptr || - registration_external->free != nullptr || - registration_external->invoke != nullptr || - registration_external->prepare != nullptr)) { + if (registration_external != nullptr && + (registration_external->init != nullptr || + registration_external->free != nullptr || + registration_external->invoke != nullptr || + registration_external->prepare != nullptr || + registration_external->async_kernel != nullptr)) { TfLiteRegistration* new_registration = RegistrationExternalToRegistration(registration_external); temporary_builtin_registrations_.push_back( diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0f21ec72e5908f..62c148a485ceb9 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/c/common_internal.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/api/tensor_utils.h" #include "tensorflow/lite/core/c/c_api_types.h" @@ -252,6 +253,7 @@ Subgraph::Subgraph(ErrorReporter* error_reporter, resource::InitializationStatusMap* initialization_status_map, int subgraph_index) : external_contexts_(external_contexts), + registration_externals_(new internal::RegistrationExternalsCache), error_reporter_(error_reporter), next_execution_plan_index_to_prepare_(0), next_execution_plan_index_to_plan_allocation_(0), @@ -509,8 +511,34 @@ TfLiteStatus Subgraph::ReplaceNodeSubsetsWithDelegateKernels( // The subgraph is taking ownership of the external registration, in case the // user has supplied an opaque delegate. if (TfLiteDelegateHasValidOpaqueDelegateBuilder(delegate)) { - registration_externals_.insert(std::unique_ptr( - registration.registration_external)); + // If the user has supplied an opaque delegate, then they _must_ also use + // TfLiteRegistrationExternal. 
+ if (!registration.registration_external) { + TFLITE_LOG( + tflite::TFLITE_LOG_WARNING, + "For a delegate with the 'opaque_delegate_builder' field set, the " + "delegate kernel's TfLiteRegistration object must have the " + "'registration_external' field set."); + return kTfLiteDelegateError; + } + + // In this case, the subgraph takes ownership of the external registration. + OpResolver::OpId op_id{registration.registration_external->builtin_code, + registration.registration_external->custom_name, + registration.registration_external->version}; + auto [it, inserted] = registration_externals_->emplace( + op_id, std::unique_ptr( + registration.registration_external)); + // If there was already an entry for this op_id in the + // registration_externals_ cache, the statement above will have + // no effect on the registration_externals_ cache, + // but will deallocate registration.registration_externals. + // To ensure that registration remains valid, we need to use the + // registration_externals value that was previously in the cache. + if (!inserted) { + auto registration_external_from_cache = it->second.get(); + registration.registration_external = registration_external_from_cache; + } } // Ignore empty node replacement sets. diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index abaa163f372f83..3e30b5afe4c4a3 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/c/common_internal.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/profiler.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/macros.h" @@ -950,11 +951,13 @@ class Subgraph { // the TfLiteRegistrationExternal objects contained in this fielld. // // LINT.IfChange - // Ideally we could include c_api.h and use - // 'TfLiteRegistrationExternalDelete' as the deleter, but that would create a - // dependency cycle. - std::unordered_set< // NOLINT - std::unique_ptr> + // The definition of RegistrationExternalsCache implicitly assumes that + // TfLiteRegistrationExternalDelete is the same as the standard C++ delete + // operator. + // TODO(b/238435088): in op_resolver, include registration_external.h and use + // 'TfLiteRegistrationExternalDelete' as the deleter, then we can eliminate + // the IfChange...ThenChange directive below. 
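The "no effect on the registration_externals_ cache, but will deallocate" comment in the subgraph.cc hunk above relies on how emplace behaves with unique_ptr values when the key already exists. A self-contained illustration in plain C++ (not TFLite code):

#include <cassert>
#include <memory>
#include <unordered_map>

int main() {
  std::unordered_map<int, std::unique_ptr<int>> cache;
  cache.emplace(1, std::make_unique<int>(10));

  int* duplicate = new int(20);
  auto [it, inserted] = cache.emplace(1, std::unique_ptr<int>(duplicate));
  assert(!inserted);          // key 1 already present: the map is unchanged...
  assert(*it->second == 10);  // ...and still owns the old value, while the
                              // temporary unique_ptr has already deleted
                              // `duplicate`, so only it->second may be used
                              // from here on.
  return 0;
}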
+ std::shared_ptr<::tflite::internal::RegistrationExternalsCache> registration_externals_; // LINT.ThenChange(//tensorflow/lite/core/c/c_api.cc) diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index 51a36d4f7898ac..078aa0863a7d55 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -389,6 +389,7 @@ TEST_F(TestDelegate, TestCopyFromBufferInvoke) { } TEST_F(TestDelegate, TestCopyFromBuffer) { + interpreter_->Invoke(); delegate_ = std::unique_ptr(new SimpleDelegate({0, 1, 2})); TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); interpreter_->ModifyGraphWithDelegate(delegate); @@ -445,6 +446,9 @@ struct OpaqueTestDelegate { delegate_state->delegate_prepared = true; TfLiteRegistration registration{}; + registration.registration_external = TfLiteRegistrationExternalCreate( + kTfLiteBuiltinDelegate, "OpaqueTestDelegate delegate kernel", 1); + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { return kTfLiteOk; diff --git a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc index e930492f352074..db22cbe891f7c3 100644 --- a/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc +++ b/tensorflow/lite/delegates/utils/simple_opaque_delegate.cc @@ -29,7 +29,7 @@ limitations under the License. namespace tflite { namespace { -TfLiteRegistrationExternal* GetDelegateKernelRegistration( +TfLiteRegistrationExternal* CreateDelegateKernelRegistration( SimpleOpaqueDelegateInterface* delegate) { TfLiteRegistrationExternal* kernel_registration = TfLiteRegistrationExternalCreate(kTfLiteBuiltinDelegate, delegate->Name(), @@ -110,8 +110,9 @@ TfLiteStatus DelegatePrepare(TfLiteOpaqueContext* opaque_context, } TfLiteRegistrationExternal* delegate_kernel_registration = - GetDelegateKernelRegistration(simple_opaque_delegate); + CreateDelegateKernelRegistration(simple_opaque_delegate); + // Transfers ownership of delegate_kernel_registration to the opaque_context. return TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels( opaque_context, delegate_kernel_registration, BuildTfLiteArray(supported_nodes).get(), opaque_delegate); diff --git a/tensorflow/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl index c9614dbb4113d1..c13438ee4fdd38 100644 --- a/tensorflow/lite/special_rules.bzl +++ b/tensorflow/lite/special_rules.bzl @@ -48,6 +48,12 @@ def op_resolver_internal_visibility_allowlist(): This is a no-op outside of Google.""" return [] +def c_api_opaque_internal_visibility_allowlist(): + """Returns a list of packages that can depend on tensorflow/lite/c:c_api_opaque_internal. + + This is a no-op outside of Google.""" + return [] + def nnapi_plugin_impl_visibility_allowlist(): """Returns a list of packages that can depend on tensorflow/lite/acceleration/configuration:nnapi_plugin_impl. diff --git a/tensorflow/lite/tools/serialization/writer_lib.h b/tensorflow/lite/tools/serialization/writer_lib.h index 7bd6313ce51f28..2c0aa924813cb0 100644 --- a/tensorflow/lite/tools/serialization/writer_lib.h +++ b/tensorflow/lite/tools/serialization/writer_lib.h @@ -24,13 +24,30 @@ limitations under the License. #include #include +// This #include needs to precede the inclusion of any other TF Lite header +// file that might depend on the non-mutable schema_generated.h, directly, +// e.g. core/api/op_resolver.h, or indirectly, e.g. core/subgraph.h. 
+// That's because "tensorflow/lite/schema/mutable/schema_generated.h" +// and "tensorflow/lite/schema/schema_generated.h" both use the same +// header guard macro (FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_), but have +// different contents (the former is a superset of the latter). In particular +// the one in mutable/ is built with the "--gen-mutable" and "--gen-object-api" +// flags to the flatbuffer schema compiler which cause some additional +// (non-virtual) accessor methods and API functions to be declared. +// The code here uses those methods, so we need to make sure that we get +// the mutable variant of this header. +// +// The '#if' here prevents automatic reordering of this #include. +#if 1 +#include "tensorflow/lite/schema/mutable/schema_generated.h" +#endif + #include "absl/container/flat_hash_map.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/context_util.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/core/subgraph.h" -#include "tensorflow/lite/schema/mutable/schema_generated.h" #include "tensorflow/lite/tools/serialization/enum_mapping.h" #include "tensorflow/lite/version.h" From d543bec0b795358eac7c8b6d350209ecec067e71 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jul 2023 16:46:51 -0700 Subject: [PATCH 079/410] [xla:gpu] Don't build IREE transitive dependencies by default Put all IREE dependencies under a build flag PiperOrigin-RevId: 550710509 --- .../mlir/backends/openxla/transforms/BUILD | 2 +- .../mlir/backends/openxla/transforms/passes.h | 4 +- .../compiler/xla/service/gpu/openxla/BUILD | 64 ++++++++++++------- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD index 2856f2551b3c7b..c72c0f2ce1db5d 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD @@ -58,6 +58,6 @@ package( cc_library( name = "passes", hdrs = ["passes.h"], - defines = ["XLA_DISABLE_OPENXLA_RUNTIME=1"], + defines = ["XLA_DISABLE_OPENXLA_COMPILER=1"], ) # copybara:comment_end diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h index 7249094b880e33..50ed90d121d407 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h @@ -19,7 +19,7 @@ limitations under the License. //===----------------------------------------------------------------------===// // TODO(ezhulenev): We currently do not build with OpenXLA runtime in open // source because we do not have bazel dependency from XLA to IREE. 
-#if XLA_DISABLE_OPENXLA_RUNTIME +#if XLA_DISABLE_OPENXLA_COMPILER //===----------------------------------------------------------------------===// namespace mlir { @@ -33,7 +33,7 @@ inline void populateOpenXlaRuntimePasses(mlir::OpPassManager&, ThunkSequence*) { } // namespace xla::gpu //===----------------------------------------------------------------------===// -#else // !XLA_DISABLE_OPENXLA_RUNTIME +#else // !XLA_DISABLE_OPENXLA_COMPILER //===----------------------------------------------------------------------===// #include diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index 50113a4ffabf34..38aa94529445e2 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -14,32 +14,34 @@ package_group( # copybara:uncomment_begin(not supported in OSS build) # -# # Add `--define=xla_gpu_bundle_lib_iree_compiler=1` to build command to bundle `libIREECompiler.so` -# # with XLA:GPU by default. Otherwise use `XLA_OPENXLA_IREE_COMPILER_LIB` environment variable to -# # load custom compiler library. +# # Add `--define=xla_gpu_with_openxla_runtime=1` to build command to enable experimental OpenXLA/IREE +# # backend for XLA:GPU executables. # config_setting( -# name = "bundle_lib_iree_compiler", +# name = "with_openxla_runtime", # values = { -# "define": "xla_gpu_bundle_lib_iree_compiler=1", +# "define": "xla_gpu_with_openxla_runtime=1", # }, # ) # # cc_library( # name = "compiler", -# srcs = ["compiler.cc"], -# hdrs = ["compiler.h"], +# srcs = select({ +# ":with_openxla_runtime": ["compiler.cc"], +# "//conditions:default": [], +# }), +# hdrs = select({ +# ":with_openxla_runtime": ["compiler.h"], +# "//conditions:default": [], +# }), # # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not # # compatible with `non_prod` constraint. # compatible_with = [], # data = select({ -# ":bundle_lib_iree_compiler": ["//third_party/iree/lib:libIREECompiler.so"], +# ":with_openxla_runtime": ["//third_party/iree/lib:libIREECompiler.so"], # "//conditions:default": [], # }), # deps = [ # "@com_google_absl//absl/base", -# "//third_party/iree/compiler/bindings/c:headers", -# "//third_party/iree/compiler/bindings/c:loader", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # "@llvm-project//llvm:Support", # "@llvm-project//mlir:IR", # "@llvm-project//mlir:Support", @@ -49,28 +51,34 @@ package_group( # ] + tf_platform_deps( # "compiler", # platform_dir = "//tensorflow/compiler/xla/service/gpu/openxla/", -# ), +# ) + select({ +# ":with_openxla_runtime": [ +# "//third_party/iree/compiler/bindings/c:headers", +# "//third_party/iree/compiler/bindings/c:loader", +# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", +# ], +# "//conditions:default": [], +# }), # ) # # cc_library( # name = "executable", -# srcs = ["executable.cc"], +# srcs = select({ +# ":with_openxla_runtime": ["executable.cc"], +# "//conditions:default": [], +# }), # hdrs = ["executable.h"], # # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not # # compatible with `non_prod` constraint. 
# compatible_with = [], +# defines = select({ +# ":with_openxla_runtime": [], +# "//conditions:default": ["XLA_DISABLE_OPENXLA_RUNTIME=1"], +# }), # deps = [ -# ":compiler", # "@com_google_absl//absl/log", # "@com_google_absl//absl/log:check", # "@com_google_absl//absl/strings", -# "//third_party/iree/runtime/src/iree/base", -# "//third_party/iree/runtime/src/iree/hal", -# "//third_party/iree/runtime/src/iree/hal/drivers/cuda", -# "//third_party/iree/runtime/src/iree/modules/hal", -# "//third_party/iree/runtime/src/iree/modules/hal:types", -# "//third_party/iree/runtime/src/iree/vm", -# "//third_party/iree/runtime/src/iree/vm/bytecode:module", # "@llvm-project//mlir:IR", # "//tensorflow/compiler/xla:status", # "//tensorflow/compiler/xla:statusor", @@ -80,7 +88,19 @@ package_group( # "//tensorflow/compiler/xla/service:executable", # "//tensorflow/compiler/xla/service/gpu:buffer_allocations", # "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", -# ], +# ] + select({ +# ":with_openxla_runtime": [ +# ":compiler", +# "//third_party/iree/runtime/src/iree/base", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/hal/drivers/cuda", +# "//third_party/iree/runtime/src/iree/modules/hal", +# "//third_party/iree/runtime/src/iree/modules/hal:types", +# "//third_party/iree/runtime/src/iree/vm", +# "//third_party/iree/runtime/src/iree/vm/bytecode:module", +# ], +# "//conditions:default": [], +# }), # ) # # copybara:uncomment_end_and_comment_begin From 4f94930cb59eae43ca20a9b1ff92f3ca7c7ac6ee Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 24 Jul 2023 09:05:02 -0700 Subject: [PATCH 080/410] Using oneDNN v3.x by default on Linux and Windows x86 builds --- .bazelrc | 2 - tensorflow/BUILD | 8 +- .../lite/experimental/tac/py_wrapper/BUILD | 8 +- .../core/common_runtime/mkl_layout_pass.cc | 18 --- tensorflow/core/kernels/BUILD | 8 +- .../core/kernels/mkl/mkl_avgpooling_op.cc | 18 +-- tensorflow/core/kernels/mkl/mkl_concat_op.cc | 12 +- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 28 ++-- .../kernels/mkl/mkl_conv_grad_input_ops.cc | 32 ++-- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 120 +++++++-------- tensorflow/core/kernels/mkl/mkl_conv_ops.h | 4 +- .../core/kernels/mkl/mkl_dequantize_op.cc | 8 +- .../mkl/mkl_eltwise_activation_base_op.h | 36 ++--- .../kernels/mkl/mkl_fused_batch_norm_op.cc | 144 +++++++++--------- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 24 +-- tensorflow/core/kernels/mkl/mkl_kernel_util.h | 4 +- .../core/kernels/mkl/mkl_layer_norm_op.cc | 12 +- .../core/kernels/mkl/mkl_matmul_ops_common.h | 76 ++++----- .../core/kernels/mkl/mkl_maxpooling_op.cc | 18 +-- .../kernels/mkl/mkl_pooling_ops_common.cc | 30 ++-- .../core/kernels/mkl/mkl_pooling_ops_common.h | 60 ++++---- tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 24 +-- .../core/kernels/mkl/mkl_qmatmul_op_test.cc | 8 +- .../core/kernels/mkl/mkl_quantize_op.cc | 60 ++++---- .../mkl_quantized_conv_ops_perchannel_test.cc | 4 +- tensorflow/core/kernels/mkl/mkl_relu_op.cc | 7 +- .../core/kernels/mkl/mkl_relu_op_test.cc | 7 +- .../mkl/mkl_requantize_per_channel_op.cc | 8 +- tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 16 +- tensorflow/core/util/mkl_util.h | 20 +-- tensorflow/core/util/mkl_util_test.cc | 4 + tensorflow/python/BUILD | 5 +- tensorflow/tensorflow.bzl | 9 +- tensorflow/tsl/BUILD | 18 --- tensorflow/tsl/framework/contraction/BUILD | 2 +- tensorflow/tsl/mkl/build_defs.bzl | 5 +- tensorflow/tsl/tsl.bzl | 4 +- tensorflow/workspace2.bzl | 8 +- third_party/mkl_dnn/BUILD | 21 --- 
third_party/mkl_dnn/build_defs.bzl | 14 -- third_party/mkl_dnn/mkldnn_v1.BUILD | 7 +- 41 files changed, 418 insertions(+), 503 deletions(-) diff --git a/.bazelrc b/.bazelrc index f09b953b710062..b860e90fc79360 100644 --- a/.bazelrc +++ b/.bazelrc @@ -202,8 +202,6 @@ build:monolithic --define framework_shared_object=false build:monolithic --define tsl_protobuf_header_only=false build:monolithic --experimental_link_static_libraries_once=false # b/229868128 -build:linux --define=build_with_onednn_v2=true - # Please note that MKL on MacOS is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ddf9b47151a9fe..10e5031f9af628 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -32,10 +32,6 @@ load( "//third_party/mkl:build_defs.bzl", "if_mkl_ml", ) -load( - "//third_party/mkl_dnn:build_defs.bzl", - "if_onednn_v3", -) load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load( "//tensorflow:tensorflow.default.bzl", @@ -108,10 +104,10 @@ PACKAGE_STATIC_DEPS = [ "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", - "@mkl_dnn_v1//:__subpackages__", "@ml_dtypes//:__subpackages__", "@nccl_archive//:__subpackages__", "@nvtx_archive//:__subpackages__", + "@onednn//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@snappy//:__subpackages__", @@ -129,7 +125,7 @@ PACKAGE_STATIC_DEPS = [ "@flatbuffers//:__subpackages__", "@nccl_archive//:__subpackages__", "@triton//:__subpackages__", -] + tsl_async_value_deps() + if_onednn_v3(["@onednn_v3//:__subpackages__"]) +] + tsl_async_value_deps() package( # copybara:uncomment default_applicable_licenses = [":license"], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index 57ee70321ee7f5..7d90b84c626116 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -1,9 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "VERSION") -load( - "//third_party/mkl_dnn:build_defs.bzl", - "if_onednn_v3", -) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -85,8 +81,8 @@ pybind_extension( "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", - "@mkl_dnn_v1//:__subpackages__", "@nsync//:__subpackages__", + "@onednn//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@png//:__subpackages__", @@ -100,7 +96,7 @@ pybind_extension( "@upb//:__subpackages__", "@XNNPACK//:__subpackages__", "@zlib//:__subpackages__", - ] + if_onednn_v3(["@onednn_v3//:__subpackages__"]), + ], deps = [ ":tac_wrapper_lib", "//tensorflow/python/lib/core:pybind11_lib", diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index bd96f0c67287a1..1ff9a579097085 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -384,7 +384,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { const bool native_fmt = NativeFormatEnabled(); // NOTE: names are alphabetically sorted. 
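The long run of guard changes that follows applies one mechanical flip throughout: the oneDNN v3 code path becomes the default and the v2 path becomes opt-in; in mkl_layout_pass.cc the #ifndef/#endif pairs are removed outright, so those rewrite registrations now apply unconditionally. Schematically (a summary, not an actual hunk):

// Before this commit: oneDNN v2 by default, v3 opt-in.
#ifndef ENABLE_ONEDNN_V3
  /* v2-only code path */
#else
  /* v3 code path */
#endif  // !ENABLE_ONEDNN_V3

// After this commit: oneDNN v3 by default, v2 opt-in.
#ifdef ENABLE_ONEDNN_V2
  /* v2-only code path */
#else
  /* v3 code path */
#endif  // ENABLE_ONEDNN_V2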
-#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), @@ -393,7 +392,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); @@ -517,7 +515,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { : csinfo_.mkl_fused_matmul, CopyAttrsAllCheckConstFilter, FusedMatMulRewrite, GetRewriteCause()}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.fused_pad_conv2d, csinfo_.mkl_native_pad_with_conv2d, CopyAttrsAllCheckConstFilter, AlwaysRewrite, kRewriteForOpNameChange}); @@ -529,18 +526,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), CopyAttrsAll, LrnGradRewrite, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.matmul, mkl_op_registry::GetMklOpName(csinfo_.matmul), CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.leakyrelu_grad, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad), CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool), CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()}); @@ -553,14 +547,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.max_pool3d_grad, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad), CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.pad_with_conv2d, native_fmt ? 
csinfo_.mkl_native_pad_with_conv2d : csinfo_.mkl_pad_with_conv2d, @@ -577,7 +569,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.quantized_concatv2, mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), CopyAttrsQuantizedConv2D, AlwaysRewrite, @@ -619,11 +610,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_max_pool, mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool), CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu, mkl_op_registry::GetMklOpName( csinfo_.quantized_conv2d_with_bias_sum_and_relu), @@ -639,7 +628,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName( csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_matmul_with_bias, mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias), @@ -666,7 +654,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.quantized_matmul_with_bias_and_dequantize), CopyAttrsQuantizedMatMulWithBiasAndDequantize, AlwaysRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back( {csinfo_.quantized_depthwise_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d), @@ -687,12 +674,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_ .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize), CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.quantize_v2, mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), CopyAttrsAll, QuantizeOpRewrite, kRewriteForOpNameChange}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.relu_grad, @@ -719,12 +704,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsAll, RewriteIfX86, GetRewriteCause()}); -#ifndef ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.squared_difference, mkl_op_registry::GetMklOpName(csinfo_.squared_difference), CopyAttrsAll, RewriteIfAtleastOneMklInput, @@ -732,7 +715,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); -#endif // !ENABLE_ONEDNN_V3 rinfo_.push_back({csinfo_.transpose, mkl_op_registry::GetMklOpName(csinfo_.transpose), CopyAttrsAll, RewriteIfX86, kRewriteForOpNameChange}); diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5042b6cebabfc7..7b67cb4cac4e54 100644 --- a/tensorflow/core/kernels/BUILD +++ 
b/tensorflow/core/kernels/BUILD @@ -31,10 +31,6 @@ load( "if_mkl", "mkl_deps", ) -load( - "//third_party/mkl_dnn:build_defs.bzl", - "if_onednn_v3", -) load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "get_compatible_with_portable", "tf_cc_shared_library", "tf_cuda_cc_test", "tf_cuda_cc_tests", "tf_disable_ptxas_warning_flags", "tf_kernel_library") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -7864,8 +7860,8 @@ tf_cc_shared_library( "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", - "@mkl_dnn_v1//:__subpackages__", "@nsync//:__subpackages__", + "@onednn//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@png//:__subpackages__", @@ -7875,7 +7871,7 @@ tf_cc_shared_library( "@upb//:__subpackages__", "@zlib//:__subpackages__", # copybara:comment_end - ] + if_onednn_v3(["@onednn_v3//:__subpackages__"]), + ], visibility = ["//visibility:public"], deps = [ ":kernel_platform_strings", diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index 51a9cd21b7501b..169d2878e460f7 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -83,12 +83,12 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { memory::dims filter_dims, strides, padding_left, padding_right; // Get src/filter/stride/padding information. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 &padding_left, &padding_right, is_pool2d); // Get the input memory descriptor. @@ -114,11 +114,11 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind = prop_kind::forward_training; MklPoolingParams fwdParams( -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 src_dims, output_dims_mkl_order, filter_dims, strides, #else src_dims, output_dims_mkl_order, filter_dims, strides, dilations, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 padding_left, padding_right, dnnl::algorithm::pooling_avg_exclude_padding, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, @@ -285,12 +285,12 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { output_shape); memory::dims filter_dims, strides, padding_left, padding_right; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 &padding_left, &padding_right, is_pool2d); memory::dims orig_input_dims_mkl_order = @@ -335,11 +335,11 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { // that is used in the backward pass. 
MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 strides, padding_left, padding_right, #else strides, dilations, padding_left, padding_right, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 dnnl::algorithm::pooling_avg_exclude_padding, prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, @@ -450,7 +450,6 @@ TF_CALL_float(REGISTER_MKL_AVGPOOL_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL_KERNELS); #undef REGISTER_MKL_AVGPOOL_KERNELS -#ifndef ENABLE_ONEDNN_V3 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedAvgPool") .Device(DEVICE_CPU) .TypeConstraint("T") @@ -462,7 +461,6 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedAvgPool") .TypeConstraint("T") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklAvgPoolingOp); -#endif // !ENABLE_ONEDNN_V3 } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 497f997860176b..f41b8f018827b7 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -40,7 +40,7 @@ using dnnl::concat; using dnnl::stream; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define CONCAT_PRIM_DESC(eng, concat_dims, src_md, dst_md_ptr) \ concat::primitive_desc(*dst_md_ptr, concat_dims, src_md, eng) #define CONCAT_PRIM_DESC_USING_SRC(eng, concat_dims, src_md) \ @@ -54,7 +54,7 @@ namespace tensorflow { concat::primitive_desc(eng, concat_dims, src_md) #define GET_MEMORY_DESC(md) md #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 typedef Eigen::ThreadPoolDevice CPUDevice; // List of TensorShape objects. Used in Concat/Split layers. 
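For readers tracing the CONCAT_PRIM_DESC shims above, the following is a minimal sketch (not part of this patch) of how the same concat primitive descriptor is built under the two builds; the helper name MakeConcatPd and its arguments are hypothetical, and the overloads are the ones the macros above appear to expand to.

#include <vector>
#include "dnnl.hpp"

// Sketch only: resolves the CONCAT_PRIM_DESC_USING_SRC shim for both builds.
dnnl::concat::primitive_desc MakeConcatPd(
    const dnnl::engine& eng, int concat_dim,
    const std::vector<dnnl::memory::desc>& src_mds) {
#ifdef ENABLE_ONEDNN_V2
  // oneDNN v2.x overload: the engine is passed last.
  return dnnl::concat::primitive_desc(concat_dim, src_mds, eng);
#else
  // oneDNN v3.x overload: the engine is passed first.
  return dnnl::concat::primitive_desc(eng, concat_dim, src_mds);
#endif
}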
@@ -302,7 +302,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { #endif DCHECK_EQ(in_data.size(), context_.data_mem.size()); for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.data_mem_shdptr[i]->set_data_handle( static_cast(in_data[i].get_data_handle()), *fwd_stream); } @@ -314,7 +314,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { } context_.dst_mem->set_data_handle( static_cast(dst_data.get_data_handle())); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { context_.data_mem[i] = *context_.data_mem_shdptr[i]; @@ -663,7 +663,7 @@ class MklConcatOp : public OpKernel { auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat( mkl_input_shapes[k].GetTfDataFormat()); if (src_tf_fmt != mkl_common_format) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); #else @@ -673,7 +673,7 @@ class MklConcatOp : public OpKernel { else if (src_md.get_ndims() == 4) src_dims = {src_md.get_dims()[0], src_md.get_dims()[1], src_md.get_dims()[2], src_md.get_dims()[3]}; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 src_md = memory::desc(src_dims, MklDnnType(), mkl_common_format); } diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index 7c24a23d68226e..dca67b791cf796 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -36,9 +36,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 using ConvBwdFilterDesc = dnnl::convolution_backward_weights::desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc; struct MklConvBwdFilterParams { @@ -96,7 +96,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *bwd_filter_stream); @@ -121,7 +121,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { } context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream, context_.bwd_filter_primitives_args); @@ -159,15 +159,15 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { // Primitive descriptor and descriptor for convolution backward filter. std::shared_ptr bwd_filter_pd; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr bwd_filter_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Primitive descriptor and descriptor for convolution forward. std::shared_ptr fwd_pd; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Convolution backward filter primitive. 
std::shared_ptr conv_bwd_filter; @@ -188,13 +188,13 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { diff_filter_mem(nullptr), diff_bias_mem(nullptr), diff_dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 bwd_filter_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 src_md(nullptr), diff_filter_md(nullptr), diff_bias_md(nullptr), @@ -228,7 +228,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType(), memory::format_tag::x)); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // Create descriptor and primitive descriptor for convolution forward. context_.fwd_desc.reset(new ConvFwdDesc( prop_kind::forward, dnnl::algorithm::convolution_direct, @@ -276,7 +276,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, *context_.fwd_pd)); } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto bwd_filter_pd = context_.bwd_filter_pd.get(); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index 16a6db176843b1..d6f71efd2f4c56 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -39,15 +39,15 @@ using dnnl::prop_kind; using dnnl::stream; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 using ConvBwdDataDesc = dnnl::convolution_backward_data::desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using ConvBwdDataPd = dnnl::convolution_backward_data::primitive_desc; // Utility classes for enabling primitive reuse for conv bwd input. @@ -103,7 +103,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.diff_src_mem->set_data_handle( static_cast(const_cast(diff_src_data)), *bwd_input_stream); @@ -118,7 +118,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { static_cast(const_cast(filter_data))); context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 execute_primitives(context_.bwd_input_primitives, bwd_input_stream, context_.bwd_input_primitives_args); @@ -143,15 +143,15 @@ class MklConvBwdInputPrimitive : public MklPrimitive { // Conv backward input primitive descriptor and descriptor. std::shared_ptr bwd_input_pd; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr bwd_input_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Primitive descriptor and descriptor for conv fwd std::shared_ptr fwd_pd; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Conv bwd input primitive. 
std::shared_ptr conv_bwd_input; @@ -170,13 +170,13 @@ class MklConvBwdInputPrimitive : public MklPrimitive { filter_mem(nullptr), diff_dst_mem(nullptr), bwd_input_pd(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 bwd_input_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 conv_bwd_input(nullptr), diff_src_md(nullptr), filter_md(nullptr), @@ -204,7 +204,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { memory::format_tag::any)); // Create descriptors for both conv fwd and conv bwd input. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.bwd_input_desc.reset(new ConvBwdDataDesc( dnnl::algorithm::convolution_direct, *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, @@ -233,7 +233,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, convBwdInputDims.dilations, convBwdInputDims.padding_left, convBwdInputDims.padding_right, *context_.fwd_pd)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory using dummy data. context_.diff_src_mem.reset(new memory( diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 6709876b8c80d8..e37c6c2698cd33 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -39,7 +39,7 @@ using ReorderPd = dnnl::reorder::primitive_desc; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define APPEND_DEPTHWISE(wei_dt, bias_dt, dst_dt, kernel, stride, padding, \ scales_mask, scales) \ append_dw(wei_dt, bias_dt, dst_dt, kernel, stride, padding, scales_mask, \ @@ -74,13 +74,13 @@ namespace tensorflow { #define SCALE wei_scale #define SUMMAND_SCALE_U8(summand_range, output_range) summand_range / 255.0f #define SUMMAND_SCALE_S8(summand_range, output_range) summand_range / 127.0f -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) #define FWD_STREAM , *fwd_stream #else #define FWD_STREAM -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 // TODO(intel-tf) Remove this once old API of quantized ops is abandoned namespace quantized_fusions { @@ -284,9 +284,9 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr dst_scale_mem; // Desc & primitive desc -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr fwd_pd; // Memory desc @@ -325,9 +325,9 @@ class MklConvFwdPrimitive : public MklPrimitive { src_scale_mem(nullptr), wei_scale_mem(nullptr), dst_scale_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), src_md(nullptr), filter_md(nullptr), @@ -372,7 +372,7 @@ class MklConvFwdPrimitive : public MklPrimitive { MklDnnType(), memory::format_tag::any)); } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // Create a convolution descriptor context_.fwd_desc.reset(new convolution_forward::desc( prop_kind::forward, dnnl::algorithm::convolution_direct, @@ -385,7 +385,7 @@ class MklConvFwdPrimitive : public MklPrimitive { 
*context_.src_md, *context_.filter_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } if (!convFwdDims.fuse_bn_dims.empty()) { @@ -422,7 +422,7 @@ class MklConvFwdPrimitive : public MklPrimitive { } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 post_ops.append_sum(op_scale); #else if (post_op_param.dtype != DT_INVALID) { @@ -436,8 +436,8 @@ class MklConvFwdPrimitive : public MklPrimitive { } else { post_ops.append_sum(op_scale); } -#endif //! ENABLE_ONEDNN_V3 -#ifndef ENABLE_ONEDNN_V3 +#endif //! !ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V2 } else if (post_op_param.name == "output_scale") { if (post_op_param.param.size() == 1) { post_ops_attr.set_output_scales(0, post_op_param.param); @@ -470,7 +470,7 @@ class MklConvFwdPrimitive : public MklPrimitive { memory::format_tag::x)); context_.dst_scale_mem.reset( new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else if (post_op_param.name == "fuse_bn") { post_ops.append_binary(dnnl::algorithm::binary_sub, *context_.bn_mean_md); @@ -488,7 +488,7 @@ class MklConvFwdPrimitive : public MklPrimitive { } post_ops_attr.set_post_ops(post_ops); } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_pd.reset( new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_)); #else @@ -505,7 +505,7 @@ class MklConvFwdPrimitive : public MklPrimitive { convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, post_ops_attr)); } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -530,7 +530,7 @@ class MklConvFwdPrimitive : public MklPrimitive { {DNNL_ARG_BIAS, *context_.bias_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, @@ -538,7 +538,7 @@ class MklConvFwdPrimitive : public MklPrimitive { { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem }}); } -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } else if (!convFwdDims.fuse_bn_dims.empty()) { context_.bn_scale_mem.reset( new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); @@ -566,7 +566,7 @@ class MklConvFwdPrimitive : public MklPrimitive { {DNNL_ARG_WEIGHTS, *context_.filter_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, @@ -574,7 +574,7 @@ class MklConvFwdPrimitive : public MklPrimitive { { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem }}); } -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } context_.fwd_primitives_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.conv_fwd); @@ -663,13 +663,13 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { for (auto& param : post_op_param.param) { key_creator.AddAsKey(param); } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 } else if (post_op_param.name == 
"output_scale") { #else } else if (post_op_param.name == "src_scale" || post_op_param.name == "wei_scale" || post_op_param.name == "dst_scale") { -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 key_creator.AddAsKey(post_op_param.partial_key); } else if (post_op_param.name == "fuse_bn") { key_creator.AddAsKey(post_op_param.name); @@ -1261,11 +1261,11 @@ class MklConvOp : public OpKernel { MklDnnShape* output_mkl_shape, Tensor** output_tensor) { DCHECK(output_tensor); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto dst_md = conv_prim_desc.dst_desc(); if (!std::is_same::value) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 dst_md.data.data_type = static_cast(MklDnnType()); #else @@ -1274,7 +1274,7 @@ class MklConvOp : public OpKernel { dst_md = memory::desc(output_dims_mkl_order, MklDnnType(), MklTensorFormatToMklDnnDataFormat(output_tf_format)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } #else auto dst_md = @@ -1283,7 +1283,7 @@ class MklConvOp : public OpKernel { : memory::desc(conv_prim_desc.dst_desc().get_dims(), MklDnnType(), MklTensorFormatToMklDnnDataFormat(output_tf_format)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Allocate shape of MKL tensor output_mkl_shape->SetMklTensor(true); @@ -1372,11 +1372,11 @@ class MklConvOp : public OpKernel { string data_format_str_; TensorFormat data_format_; Tensor cached_filter_data_ TF_GUARDED_BY(mu_); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 Tensor cached_filter_md_ TF_GUARDED_BY(mu_); #else FilterMemoryDesc cached_filter_md_ TF_GUARDED_BY(mu_); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Initialize to values the template is instantiated with bool fuse_biasadd_ = bias_enabled; @@ -1426,7 +1426,7 @@ class MklConvOp : public OpKernel { *filter_tensor = &cached_filter_data_; memory::desc weights_desc = conv_prim_desc.weights_desc(); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // There is no tensor format in DNNL 1.x. So we cache the complete filter // descriptor as flat byte array. 
TensorShape cached_filter_md_shape; @@ -1445,7 +1445,7 @@ class MklConvOp : public OpKernel { weights_desc.get_data_type(), weights_desc.get_dims(), weights_desc.get_inner_blks(), weights_desc.get_inner_idxs(), weights_desc.get_strides()); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, @@ -1506,12 +1506,12 @@ class MklConvOp : public OpKernel { return; } -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 // For now, cache filter only for blocked format if (filter_md.get_format_kind() != memory::format_kind::blocked) { return; } -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // Otherwise, cache reordered filter filter.SetUsrMem(filter_md, &filter_tensor); @@ -1527,7 +1527,7 @@ class MklConvOp : public OpKernel { memcpy(cached_filter_data, filter_data, cached_filter_data_size); } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // TODO(intel-tf): This function is no longer used and needs to be removed bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, const Tensor& cached_filter_md) { @@ -1545,14 +1545,14 @@ class MklConvOp : public OpKernel { } return true; } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 Tfilter* GetCachedFilter(OpKernelContext* context, const memory::desc& filter_md) TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& cached_filter_data = cached_filter_data_; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 const Tensor& cached_filter_md = cached_filter_md_; // Check if the memory descriptor of the cached weights is the same as @@ -1580,7 +1580,7 @@ class MklConvOp : public OpKernel { const_cast(cached_filter_data.flat().data())); } return nullptr; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } }; @@ -1971,19 +1971,19 @@ class MklQuantizedConvOp // If Requantize is fused, we set output_scale as first post op since it is // logically applied before any post op. Then we maintain the order of post // ops according to the order of fused_ops. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 int idx = fuse_requantize ? 1 : 0; #else post_op_to_idx_["src_scale"] = 0; post_op_to_idx_["wei_scale"] = 1; post_op_to_idx_["dst_scale"] = 2; int idx = 3; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 for (int i = 0; i < fused_ops_.size(); ++i) { if (fused_ops_[i] == "Requantize") { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 post_op_to_idx_["output_scale"] = 0; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else if (fused_ops_[i] == "Sum") { post_op_to_idx_["sum"] = idx++; } else if (fused_ops_[i] == "Relu") { @@ -2152,11 +2152,11 @@ class MklQuantizedConvOp std::vector SCALE(depth); float float_input_range = std::max(std::abs(min_input), std::abs(max_input)); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 float int_input_limit = std::is_same::value ? 255.0f : 127.0f; const float src_scale = float_input_range / int_input_limit; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 if (std::is_same::value || std::is_same::value) { // min_freezed_output and max_freezed_output are the actual range @@ -2178,18 +2178,18 @@ class MklQuantizedConvOp float float_filter_range = std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); // To understand the scaling, please see mkl_requantize_ops_test. 
-#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 scales[i] = int_output_limit * float_input_range * float_filter_range / (int_const_scale_limit * float_output_range); #else wei_scale[i] = float_filter_range / 127.0; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } // we are creating a partial key here to use with primitive key caching to // improve key creation performance. Instead of using actual values we are // using the pointers for min/max_filter_vector, and this works since the // filter vector here is a constant. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 FactoryKeyCreator param_key; param_key.AddAsKey(min_input); param_key.AddAsKey(max_input); @@ -2209,9 +2209,9 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {dst_scale}, dst_param_key.GetKey()}; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else { -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 if (!std::is_same::value) TF_CHECK_OK(absl::FailedPreconditionError( "Output datatype is expected to be qint32.")); @@ -2234,10 +2234,10 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {dst_scale}, dst_param_key.GetKey()}; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 FactoryKeyCreator src_param_key; src_param_key.AddAsKey(min_input); src_param_key.AddAsKey(max_input); @@ -2251,7 +2251,7 @@ class MklQuantizedConvOp src_param_key.GetKey()}; params.post_op_params[post_op_to_idx_["wei_scale"]] = { "wei_scale", dnnl::algorithm::undef, wei_scale, wei_param_key.GetKey()}; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 if (this->get_fuse_add()) { // Calculate the scale (beta in oneDNN api term) for sum DataType summand_dt = this->input_type(this->get_input_add_idx()); @@ -2326,9 +2326,9 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {1.0}, "", -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 summand_dt -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 }; } } @@ -2376,7 +2376,7 @@ class MklQuantizedConvOp "Summand cannot be forwarded in the current fusion.")); return; } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 MklConvOp< Device, Tinput, /*Tfilter*/ qint8, Tbias, Toutput, Ttemp_output, /*Tpadding*/ int32, @@ -2411,7 +2411,7 @@ class MklQuantizedConvOp std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); } dnnl::primitive_attr reorder_attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 if (depth == 1) { reorder_attr.set_output_scales(0, scales); } else { @@ -2424,7 +2424,7 @@ class MklQuantizedConvOp reorder_attr.set_scales_mask(DNNL_ARG_SRC, 0); reorder_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); reorder_attr.set_scales_mask(DNNL_ARG_DST, 0); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto summand_md = memory::desc(output_dims_mkl_order, MklDnnType(), memory::format_tag::nhwc); void* summand_buf = @@ -2456,7 +2456,7 @@ class MklQuantizedConvOp absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } } @@ -2466,7 +2466,7 @@ class MklQuantizedConvOp if (!this->get_fuse_biasadd()) { return nullptr; } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 if (std::is_same::value) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -2502,7 +2502,7 @@ class MklQuantizedConvOp } if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { dnnl::primitive_attr bias_attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 if (depth == 1) { bias_attr.set_output_scales(0, scales_); } else { 
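As a side note on the scale handling being switched in these quantized-conv hunks, here is a minimal sketch (the helper name MakeScaledAttr is hypothetical; the calls are the ones used in this file) of attaching quantization scales to a primitive_attr under the two builds. On v3.x only the masks are declared here, and the actual scale values are bound at execution time through DNNL_ARG_ATTR_SCALES memories, as the net_args hunks above show.

#include <vector>
#include "dnnl.hpp"

// Sketch only: attach quantization scales to a primitive attribute.
dnnl::primitive_attr MakeScaledAttr(const std::vector<float>& scales, int mask) {
  dnnl::primitive_attr attr;
#ifdef ENABLE_ONEDNN_V2
  // v2.x: scale values are baked into the attribute up front.
  attr.set_output_scales(mask, scales);
#else
  // v3.x: only masks are declared; values are bound later via
  // DNNL_ARG_ATTR_SCALES | DNNL_ARG_{SRC,WEIGHTS,DST} memories.
  (void)scales;
  attr.set_scales_mask(DNNL_ARG_SRC, 0);
  attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask);
  attr.set_scales_mask(DNNL_ARG_DST, 0);
#endif
  return attr;
}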
@@ -2515,7 +2515,7 @@ class MklQuantizedConvOp bias_attr.set_scales_mask(DNNL_ARG_SRC, 0); bias_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); bias_attr.set_scales_mask(DNNL_ARG_DST, 0); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto bias_md = memory::desc({static_cast(bias_tensor.NumElements())}, MklDnnType(), memory::format_tag::x); @@ -2641,7 +2641,7 @@ class MklQuantizedConvOp } return GetCachedBias(context); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } bool is_bias_const_; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.h b/tensorflow/core/kernels/mkl/mkl_conv_ops.h index 0384df4b309285..6abb2862d35e6d 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.h @@ -48,11 +48,11 @@ using dnnl::stream; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // Op descriptor is no longer supported in oneDNN v3.x. Instead, primitive // descriptor will directly accept primitive parameters during creation. using ConvFwdDesc = dnnl::convolution_forward::desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using ConvFwdPd = dnnl::convolution_forward::primitive_desc; class MklDnnConvUtil { diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index a41afd657824a4..957612606c32d4 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -152,7 +152,7 @@ class MklDequantizeOp : public OpKernel { std::vector scales; scales.push_back(scale_factor); primitive_attr attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 attr.set_output_scales(0, scales); #else attr.set_scales_mask(DNNL_ARG_SRC, 0); @@ -160,7 +160,7 @@ class MklDequantizeOp : public OpKernel { MklDnnType(), memory::format_tag::x}, cpu_engine, scales.data()); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::vector net; // Create reorder primitive and then execute. @@ -169,7 +169,7 @@ class MklDequantizeOp : public OpKernel { dst.GetUsrMem()->get_desc(), attr); net.push_back(reorder(reorder_pd)); std::vector> reorder_net_args; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 reorder_net_args.push_back({{DNNL_ARG_FROM, *src.GetUsrMem()}, { DNNL_ARG_TO, *dst.GetUsrMem() }}); @@ -178,7 +178,7 @@ class MklDequantizeOp : public OpKernel { {{DNNL_ARG_FROM, *src.GetUsrMem()}, {DNNL_ARG_TO, *dst.GetUsrMem()}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}}); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 execute_primitives(net, reorder_stream, reorder_net_args); } catch (dnnl::error& e) { string error_msg = "Status: " + std::to_string(e.status) + diff --git a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h index 9bfddc2f1d0991..ad1f251b9f4022 100644 --- a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h +++ b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h @@ -43,11 +43,11 @@ using dnnl::stream; using EltwiseFwdActivationPd = dnnl::eltwise_forward::primitive_desc; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define GET_MEMORY_DESC(md) md.data #else #define GET_MEMORY_DESC(md) md -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // TODO(tf-onednn): Consolidate this class with `MklEltWiseFwdParams` // in `mkl_relu_op.cc`. 
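The comment above ("Op descriptor is no longer supported in oneDNN v3.x...") describes the pattern most of the remaining hunks follow. A minimal sketch, using the eltwise primitive from this file and a hypothetical helper name, of the v2.x descriptor-then-primitive-descriptor flow collapsing into a single v3.x constructor:

#include "dnnl.hpp"

using EltwiseFwdPd = dnnl::eltwise_forward::primitive_desc;

// Sketch only: v2.x op descriptor + primitive descriptor vs. v3.x direct form.
EltwiseFwdPd MakeEltwiseFwdPd(const dnnl::engine& eng, dnnl::algorithm alg,
                              const dnnl::memory::desc& src_md,
                              float alpha, float beta) {
#ifdef ENABLE_ONEDNN_V2
  auto fwd_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward, alg,
                                              src_md, alpha, beta);
  return EltwiseFwdPd(fwd_desc, eng);
#else
  // v3.x also wants an explicit destination descriptor; here dst == src.
  return EltwiseFwdPd(eng, dnnl::prop_kind::forward, alg, src_md, src_md,
                      alpha, beta);
#endif
}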
@@ -59,23 +59,23 @@ class MklEltwiseFwdActivationParams { public: memory::dims src_dims; memory::desc src_md; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 algorithm alg_kind; float alpha; float beta; MklEltwiseFwdActivationParams(memory::dims src_dims, memory::desc src_md, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 algorithm alg_kind, float alpha, float beta) : src_dims(src_dims), src_md(src_md), -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 dst_md(dst_md), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 alg_kind(alg_kind), alpha(alpha), beta(beta) { @@ -134,9 +134,9 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // desc & primitive desc -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr fwd_pd; // memory desc @@ -156,9 +156,9 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { EltwiseFwdActivationContext() : src_mem(nullptr), dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), @@ -174,7 +174,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { context_.src_mpd.reset(new memory::desc(*context_.src_md)); // Create an eltwise forward descriptor and primitive descriptor -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_desc.reset(new eltwise_forward::desc( prop_kind::forward, fwdParams.alg_kind, *context_.src_md, fwdParams.alpha, fwdParams.beta)); @@ -185,7 +185,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { context_.fwd_pd.reset(new EltwiseFwdActivationPd( cpu_engine_, prop_kind::forward, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.alpha, fwdParams.beta)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto fwd_pd = context_.fwd_pd.get(); // Create memory primitive based on dummy data @@ -301,15 +301,15 @@ class MklEltwiseFwdActivationOpBase : public OpKernel { // Create blocked memory descriptor src_md = MklDnnData::CreateBlockedMemDesc(src_dims, src_strides); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md = src_md; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdActivationParams fwdParams(src_dims, src_md, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 dst_md, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 alg_kind, alpha_, beta_); MklEltwiseFwdActivationPrimitive* eltwise_fwd = MklEltwiseFwdActivationPrimitiveFactory::Get(fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 9d4736fc6a83a8..1ae5080895e320 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -41,7 +41,7 @@ using BatchNormBwdPd = dnnl::batch_normalization_backward::primitive_desc; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define FORWARD_INFERENCE prop_kind::forward_scoring #define GET_DIFF_SCALE_DATA_BUFFER diff_scale_shift_data #define GET_DIFF_SCALE_SHIFT_DATA_BUFFERS diff_scale_shift_data @@ -63,7 +63,7 @@ namespace tensorflow { #define 
SCALE_SHIFT_NET_ARGS \ {DNNL_ARG_SCALE, *context_.scale_mem}, { DNNL_ARG_SHIFT, *context_.shift_mem } #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using CPUDevice = Eigen::ThreadPoolDevice; @@ -77,16 +77,16 @@ struct MklBatchNormFwdParams { TensorFormat data_format; FusedBNActivationMode activation_mode; memory::desc src_md; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, bool training, TensorFormat data_format, memory::desc src_md, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 FusedBNActivationMode activation_mode) : src_dims(src_dims), depth(depth), @@ -94,14 +94,14 @@ struct MklBatchNormFwdParams { training(training), data_format(data_format), activation_mode(activation_mode), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 src_md(src_md) { } #else src_md(src_md), dst_md(dst_md) { } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 }; template @@ -123,18 +123,18 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst // mean_data: output data buffer of means // variance_data: output data buffer of variances -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 void Execute(const T* src_data, const U* scale_shift_data, T* dst_data, #else void Execute(const T* src_data, const U* scale_data, const U* shift_data, T* dst_data, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 U* mean_data, U* variance_data, std::shared_ptr fwd_stream, U* workspace_data) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_stream); @@ -161,7 +161,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(static_cast(dst_data)); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.scale_shift_mem->set_data_handle( static_cast(const_cast(scale_shift_data))); #else @@ -169,7 +169,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { static_cast(const_cast(scale_data))); context_.shift_mem->set_data_handle( static_cast(const_cast(shift_data))); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } if ((context_.pkind == prop_kind::forward_training) || @@ -180,7 +180,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { if (workspace_data != nullptr) { context_.ws_mem->set_data_handle(workspace_data); } -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 // Execute batch-normalization forward primitives. 
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); @@ -189,12 +189,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(DummyData); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.scale_shift_mem->set_data_handle(DummyData); #else context_.scale_mem->set_data_handle(DummyData); context_.shift_mem->set_data_handle(DummyData); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } if ((context_.pkind == prop_kind::forward_training) || @@ -225,12 +225,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // Inputs/outputs memory. std::shared_ptr src_mem; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr scale_shift_mem; #else std::shared_ptr scale_mem; std::shared_ptr shift_mem; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr dst_mem; std::shared_ptr mean_mem; std::shared_ptr variance_mem; @@ -249,12 +249,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { : flags(0), pkind(prop_kind::forward_training), src_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 scale_shift_mem(nullptr), #else scale_mem(nullptr), shift_mem(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 dst_mem(nullptr), mean_mem(nullptr), variance_mem(nullptr), @@ -275,7 +275,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // Memory descriptor auto src_md = fwdParams.src_md; // Create forward BatchNorm descriptor and primitive descriptor. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto fwd_desc = batch_normalization_forward::desc( context_.pkind, src_md, fwdParams.eps, static_cast(context_.flags)); @@ -286,7 +286,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.fwd_pd.reset(new BatchNormFwdPd( cpu_engine_, context_.pkind, src_md, dst_md, fwdParams.eps, static_cast(context_.flags))); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -296,7 +296,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { memory::dims m_dims = {1, fwdParams.depth}; if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 memory::dims s_dims = {2, fwdParams.depth}; context_.scale_shift_mem.reset( new memory({{s_dims}, MklDnnType(), memory::format_tag::nc}, @@ -309,7 +309,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.shift_mem.reset( new memory({{s_dims}, MklDnnType(), memory::format_tag::x}, cpu_engine_, DummyData)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } if (fwdParams.training || (IS_SET(use_global_stats))) { @@ -482,18 +482,18 @@ struct MklBatchNormBwdParams { bool training; TensorFormat data_format; memory::desc src_md; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md; memory::desc diff_src_md; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 memory::desc diff_dst_md; MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, int depth, float eps, bool training, TensorFormat data_format, memory::desc src_md, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md, memory::desc diff_src_md, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 memory::desc diff_dst_md) : src_dims(src_dims), diff_dst_dims(diff_dst_dims), @@ -502,10 +502,10 @@ struct MklBatchNormBwdParams { training(training), data_format(data_format), src_md(src_md), -#ifdef ENABLE_ONEDNN_V3 
+#ifndef ENABLE_ONEDNN_V2 dst_md(dst_md), diff_src_md(diff_src_md), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 diff_dst_md(diff_dst_md) { } }; @@ -533,19 +533,19 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { // intermediate results is not implemented // on CPU as of now. void Execute(const T* src_data, const U* mean_data, const U* variance_data, -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 const T* diff_dst_data, const U* scale_shift_data, T* diff_src_data, U* diff_scale_shift_data, U* res_space_data, #else // oneDNN v3.x does not require 'shift_data' const T* diff_dst_data, const U* scale_data, T* diff_src_data, U* diff_scale_data, U* diff_shift_data, U* res_space_data, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr bwd_stream) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *bwd_stream); @@ -576,7 +576,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.scale_shift_mem->set_data_handle( static_cast(const_cast(scale_shift_data))); context_.diff_scale_shift_mem->set_data_handle( @@ -588,11 +588,11 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { static_cast(diff_scale_data)); context_.diff_shift_mem->set_data_handle( static_cast(diff_shift_data)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 // Execute backward batch-normalization primitives. DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size()); execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); @@ -603,14 +603,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.variance_mem->set_data_handle(DummyData); context_.diff_dst_mem->set_data_handle(DummyData); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.scale_shift_mem->set_data_handle(DummyData); context_.diff_scale_shift_mem->set_data_handle(DummyData); #else context_.scale_mem->set_data_handle(DummyData); context_.diff_scale_mem->set_data_handle(DummyData); context_.diff_shift_mem->set_data_handle(DummyData); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } context_.diff_src_mem->set_data_handle(DummyData); } @@ -631,14 +631,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { std::shared_ptr mean_mem; std::shared_ptr variance_mem; std::shared_ptr diff_dst_mem; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr scale_shift_mem; std::shared_ptr diff_scale_shift_mem; #else std::shared_ptr scale_mem; std::shared_ptr diff_scale_mem; std::shared_ptr diff_shift_mem; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr diff_src_mem; // Backward batch-normalization primitive descriptor. 
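To make the scale/shift repacking in these batch-norm hunks easier to follow, here is a minimal sketch (helper name and memory objects hypothetical; argument macros as used above) of how the backward arguments differ: v2.x uses one interleaved 2 x depth scale_shift buffer, while v3.x uses separate depth-sized scale and shift tensors.

#include <unordered_map>
#include "dnnl.hpp"

// Sketch only: scale/shift arguments for batch-normalization backward.
#ifdef ENABLE_ONEDNN_V2
void AddScaleShiftArgs(std::unordered_map<int, dnnl::memory>* args,
                       const dnnl::memory& scale_shift,
                       const dnnl::memory& diff_scale_shift) {
  args->insert({DNNL_ARG_SCALE_SHIFT, scale_shift});
  args->insert({DNNL_ARG_DIFF_SCALE_SHIFT, diff_scale_shift});
}
#else
void AddScaleShiftArgs(std::unordered_map<int, dnnl::memory>* args,
                       const dnnl::memory& scale, const dnnl::memory& diff_scale,
                       const dnnl::memory& diff_shift) {
  args->insert({DNNL_ARG_SCALE, scale});
  args->insert({DNNL_ARG_DIFF_SCALE, diff_scale});
  args->insert({DNNL_ARG_DIFF_SHIFT, diff_shift});
}
#endif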
@@ -655,14 +655,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { mean_mem(nullptr), variance_mem(nullptr), diff_dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 scale_shift_mem(nullptr), diff_scale_shift_mem(nullptr), #else scale_mem(nullptr), diff_scale_mem(nullptr), diff_shift_mem(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 diff_src_mem(nullptr) { } }; @@ -682,19 +682,19 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { memory::format_tag::nc); auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType(), memory::format_tag::nc); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto scale_shift_desc = memory::desc({2, bwdParams.depth}, MklDnnType(), memory::format_tag::nc); #else auto scale_shift_desc = memory::desc({bwdParams.depth}, MklDnnType(), memory::format_tag::x); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto diff_scale_shift_desc = scale_shift_desc; // Forward batch-normalization descriptor and primitive descriptor. // Adding this back due to type difference with context.flags auto bn_flags = GetBatchNormFlags(bwdParams); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto fwd_desc = batch_normalization_forward::desc( prop_kind::forward_training, src_md, bwdParams.eps, static_cast(bn_flags)); @@ -719,7 +719,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { cpu_engine_, prop_kind::backward, diff_src_md, diff_dst_md, src_md, bwdParams.eps, static_cast(bn_flags), fwd_pd)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitives. context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); @@ -728,7 +728,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.variance_mem.reset( new memory(variance_desc, cpu_engine_, DummyData)); context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData)); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.scale_shift_mem.reset( new memory(scale_shift_desc, cpu_engine_, DummyData)); context_.diff_scale_shift_mem.reset( @@ -740,7 +740,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { new memory(diff_scale_shift_desc, cpu_engine_, DummyData)); context_.diff_shift_mem.reset( new memory(diff_scale_shift_desc, cpu_engine_, DummyData)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd)); @@ -750,7 +750,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { {DNNL_ARG_VARIANCE, *context_.variance_mem}, {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}, {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem}, -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 {DNNL_ARG_SCALE_SHIFT, *context_.scale_shift_mem}, { DNNL_ARG_DIFF_SCALE_SHIFT, *context_.diff_scale_shift_mem }}); @@ -759,7 +759,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { {DNNL_ARG_SCALE, *context_.scale_mem}, {DNNL_ARG_DIFF_SCALE, *context_.diff_scale_mem}, {DNNL_ARG_DIFF_SHIFT, *context_.diff_shift_mem}}); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 context_.bwd_primitives.push_back(*context_.bn_bwd); } @@ -977,12 +977,12 @@ class MklFusedBatchNormOp : public OpKernel { Tensor* reserved_space_tensor = nullptr; MklDnnData src(&cpu_engine_); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 MklDnnData scale_shift(&cpu_engine_); #else MklDnnData scale(&cpu_engine_); MklDnnData shift(&cpu_engine_); -#endif // 
!ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 MklDnnData wksp(&cpu_engine_); memory::format_tag dnn_fmt; @@ -1009,17 +1009,17 @@ class MklFusedBatchNormOp : public OpKernel { auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType(), dnn_fmt); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 auto dst_md = memory::desc(src_dims, MklDnnType(), dnn_fmt); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 tensor_format_, src_md, activation_mode_); #else tensor_format_, src_md, dst_md, activation_mode_); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. @@ -1062,7 +1062,7 @@ class MklFusedBatchNormOp : public OpKernel { else SetMeanVariance(est_mean_tensor, est_variance_tensor); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // oneDNN packs scale & shift as a combined array in float32 type // ...... scale_shift.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -1083,7 +1083,7 @@ class MklFusedBatchNormOp : public OpKernel { const U* shift_tf = shift_tensor.flat().data(); std::memcpy(scale_data, scale_tf, depth_ * sizeof(U)); std::memcpy(shift_data, shift_tf, depth_ * sizeof(U)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 char* saved_mean_data_tf = reinterpret_cast(saved_mean_tensor->flat().data()); @@ -1124,12 +1124,12 @@ class MklFusedBatchNormOp : public OpKernel { AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst, native_format); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 U* scale_shift_op_data = scale_shift_data; #else U* scale_op_data = scale_data; U* shift_op_data = shift_data; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 U* mean_op_data = saved_mean_tensor->flat().data(); U* variance_op_data = saved_variance_tensor->flat().data(); T* dst_data = dst_tensor->flat().data(); @@ -1138,12 +1138,12 @@ class MklFusedBatchNormOp : public OpKernel { std::shared_ptr fwd_cpu_stream; fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine())); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 bn_fwd->Execute(src_data, scale_shift_op_data, dst_data, mean_op_data, #else bn_fwd->Execute(src_data, scale_op_data, shift_op_data, dst_data, mean_op_data, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 variance_op_data, fwd_cpu_stream, ws_data); float adjust_factor = 1.0; if (is_training_) { @@ -1464,14 +1464,14 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnData src(&cpu_engine_); MklDnnData diff_dst(&cpu_engine_); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 MklDnnData scale_shift(&cpu_engine_); MklDnnData diff_scale_shift(&cpu_engine_); #else MklDnnData scale(&cpu_engine_); MklDnnData diff_scale(&cpu_engine_); MklDnnData diff_shift(&cpu_engine_); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 memory::dims src_dims = dnn_shape_src.IsMklTensor() @@ -1492,11 +1492,11 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_dst.IsMklTensor() ? 
dnn_shape_diff_dst.GetMklLayout() : memory::desc(diff_dst_dims, MklDnnType(), dnn_fmt); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc dst_md = memory::desc(src_dims, MklDnnType(), dnn_fmt); memory::desc diff_src_md = memory::desc(diff_dst_dims, MklDnnType(), dnn_fmt); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 MklDnnData reorder_src(&cpu_engine_); MklDnnData reorder_diff_dst(&cpu_engine_); @@ -1525,7 +1525,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // scale_shift -- oneDNN packs scales/shifts as scale_shift in order // of scale, ..., scale, shift, ...., shift scale_shift.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -1546,13 +1546,13 @@ class MklFusedBatchNormGradOp : public OpKernel { std::memcpy(scale_data_tf, scale_tf, depth_ * sizeof(U)); diff_scale.AllocateBuffer(depth_ * sizeof(U)); diff_shift.AllocateBuffer(depth_ * sizeof(U)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 dst_md, diff_src_md, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 diff_dst_md); Eigen::ThreadPoolInterface* eigen_interface = EigenThreadPoolFromTfContext(context); @@ -1600,7 +1600,7 @@ class MklFusedBatchNormGradOp : public OpKernel { static_cast(const_cast(saved_mean_tensor.flat().data())); U* variance_data = static_cast( const_cast(saved_variance_tensor.flat().data())); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 U* scale_shift_data = scale_shift_data_tf; U* diff_scale_shift_data = static_cast(diff_scale_shift.GetAllocatedBuffer()); @@ -1608,7 +1608,7 @@ class MklFusedBatchNormGradOp : public OpKernel { U* scale_data = scale_data_tf; U* diff_scale_data = static_cast(diff_scale.GetAllocatedBuffer()); U* diff_shift_data = static_cast(diff_shift.GetAllocatedBuffer()); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 T* diff_src_data = static_cast(diff_src_tensor->flat().data()); U* res_space_data = diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index b33be9350aedc7..c7dd9c585e04b6 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -104,11 +104,11 @@ class MklFusedInstanceNormOp : public OpKernel { } auto src_md = memory::desc(src_dims, MklDnnType(), tag); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define NUM_DUPLICATE 2 #else #define NUM_DUPLICATE 1 -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 memory::dims scale_shift_dims = { static_cast(NUM_DUPLICATE * num_elements_scale)}; auto scale_shift_md = memory::desc(scale_shift_dims, MklDnnType(), @@ -116,7 +116,7 @@ class MklFusedInstanceNormOp : public OpKernel { int64_t tensor_shape = scale_shift_md.get_size() / sizeof(float); #undef NUM_DUPLICATE -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 Tensor scale_shift_tensor; OP_REQUIRES_OK( ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, @@ -150,12 +150,12 @@ class MklFusedInstanceNormOp : public OpKernel { num_elements_scale, scale_fp32_buf, shift_fp32_buf); auto scale_mem = memory(scale_shift_md, cpu_engine_, scale_fp32_buf); auto shift_mem = memory(scale_shift_md, cpu_engine_, shift_fp32_buf); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 batch_normalization_forward::primitive_desc bnorm_pd; if 
(fuse_activation_) { dnnl::post_ops post_ops; dnnl::primitive_attr post_ops_attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 post_ops.append_eltwise(1.0, dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, 0.0); post_ops_attr.set_post_ops(post_ops); @@ -169,16 +169,16 @@ class MklFusedInstanceNormOp : public OpKernel { cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift, post_ops_attr); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 bnorm_pd = batch_normalization_forward::primitive_desc(bnorm_desc, cpu_engine_); #else bnorm_pd = batch_normalization_forward::primitive_desc( cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } auto bnorm_prim = batch_normalization_forward(bnorm_pd); @@ -198,12 +198,12 @@ class MklFusedInstanceNormOp : public OpKernel { std::unordered_map bnorm_args; bnorm_args.insert({DNNL_ARG_SRC, *src_mem_ptr}); bnorm_args.insert({DNNL_ARG_DST, *dst_mem_ptr}); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); #else bnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); bnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Perform batchnorm computation for each batch in input for (int i = 0; i < batch_size; i++) { @@ -287,14 +287,14 @@ class MklFusedInstanceNormOp : public OpKernel { auto data_size = sizeof(float) * num_elements; void* scale_buf_dst = fp32_scale_or_combine_buf; void* shift_buf_dst = nullptr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 shift_buf_dst = static_cast(fp32_scale_or_combine_buf) + data_size; (void)fp32_shift_buf; #else OP_REQUIRES(ctx, (fp32_shift_buf != nullptr), absl::InvalidArgumentError("Invalid shift buffer")); shift_buf_dst = fp32_shift_buf; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 if (std::is_same::value) { memcpy(scale_buf_dst, scale_buf_src, data_size); diff --git a/tensorflow/core/kernels/mkl/mkl_kernel_util.h b/tensorflow/core/kernels/mkl/mkl_kernel_util.h index 2e7ff537cfa000..4dd6c17dd55f84 100644 --- a/tensorflow/core/kernels/mkl/mkl_kernel_util.h +++ b/tensorflow/core/kernels/mkl/mkl_kernel_util.h @@ -49,7 +49,7 @@ class MklTestingUtil { } }; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 // Since oneDNN v3.x exposes only an opaque memory descriptor, it is no longer // possible to cache the entire filter memory descriptor as is. So we store // all relevant information about it in the following class. @@ -94,7 +94,7 @@ class FilterMemoryDesc { memory::dims inner_idxs_; memory::dims strides_; }; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index ae5ad08b3f4393..fd92b8309cfe9f 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -79,7 +79,7 @@ class MklLayerNormOp : public OpKernel { static_cast(const_cast(src_tensor.flat().data())); auto src_mem = memory(src_md, cpu_engine, src_buf); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // oneDNN v2.x requires scale-shift as a combined array in float32 type. 
memory::dims scale_shift_dims = { 2, static_cast(num_elements_scale)}; @@ -124,7 +124,7 @@ class MklLayerNormOp : public OpKernel { void* shift_buf_dst = static_cast(shift_buf_tensor.flat().data()); auto shift_mem = memory(scale_shift_md, cpu_engine, shift_buf_dst); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 void* scale_buf_src = static_cast(const_cast(scale_tensor.flat().data())); @@ -159,7 +159,7 @@ class MklLayerNormOp : public OpKernel { shift_reorder_prim.execute(*cpu_stream, shift_reorder_args); // Create layer_normalization primitive -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto lnorm_desc = layer_normalization_forward::desc( prop_kind::forward_inference, src_md, epsilon_, normalization_flags::use_scale_shift); @@ -170,7 +170,7 @@ class MklLayerNormOp : public OpKernel { auto lnorm_pd = layer_normalization_forward::primitive_desc( cpu_engine, prop_kind::forward_inference, src_md, dst_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 auto lnorm_prim = layer_normalization_forward(lnorm_pd); // mean and variance memory @@ -189,12 +189,12 @@ class MklLayerNormOp : public OpKernel { lnorm_args.insert({DNNL_ARG_SRC, src_mem}); lnorm_args.insert({DNNL_ARG_MEAN, mean_mem}); lnorm_args.insert({DNNL_ARG_VARIANCE, variance_mem}); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 lnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); #else lnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); lnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 lnorm_args.insert({DNNL_ARG_DST, dst_mem}); lnorm_prim.execute(*cpu_stream, lnorm_args); } catch (dnnl::error& e) { diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 4b86eb50d5863e..60852930bc5a7a 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -40,7 +40,7 @@ using dnnl::stream; namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define APPEND_ELTWISE(scale, alg, alpha, beta) \ append_eltwise(scale, alg, alpha, beta) #define APPEND_ELTWISE_RELU6(scale, alpha, beta) \ @@ -58,13 +58,13 @@ namespace tensorflow { (post_op_param.name == "dst_scale") #define SET_MKL_LAYOUT(md) SetMklLayout(md) #define TSCALED_BIAS float -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) #define FWD_STREAM , *fwd_stream #else #define FWD_STREAM -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 static Eigen::internal::CacheSizes cache_sizes = Eigen::internal::CacheSizes(); @@ -208,9 +208,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { std::shared_ptr dst_scale_mem; // Descriptor and primitive-descriptor for forward inner-product. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr fwd_pd; // Memory descriptors. 
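The FWD_STREAM macro defined just above is easiest to read expanded. A minimal sketch (hypothetical helper; overloads as used throughout these kernels) of the data-handle binding it controls: the stream-taking overload is only used on non-OpenMP builds of oneDNN v2.x.

#include "dnnl.hpp"

// Sketch only: what `set_data_handle(... FWD_STREAM)` expands to.
void BindDataHandle(dnnl::memory* mem, void* data, dnnl::stream* fwd_stream) {
#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2)
  // Threadpool (non-OpenMP) runtime on v2.x: pass the stream explicitly.
  mem->set_data_handle(data, *fwd_stream);
#else
  // OpenMP runtime, or any v3.x build: the one-argument overload is used.
  (void)fwd_stream;
  mem->set_data_handle(data);
#endif
}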
@@ -238,9 +238,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { src_scale_mem(nullptr), wei_scale_mem(nullptr), dst_scale_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), src_md(nullptr), weight_md(nullptr), @@ -282,7 +282,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { memory::format_tag::any)); } // Create an inner-product. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_desc.reset(new inner_product_forward::desc( matmul_fwd_params.const_weight ? prop_kind::forward_inference : prop_kind::forward_training, @@ -290,7 +290,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { *context_.dst_md)); context_.fwd_pd.reset(new inner_product_forward::primitive_desc( *context_.fwd_desc, cpu_engine_)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Check if there is any fusion as post-ops auto const& post_op_params = matmul_fwd_params.post_op_params; @@ -348,7 +348,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { float op_beta = post_op_param.param[2]; post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_logistic, op_alpha, op_beta); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 } else if (post_op_param.name == "output_scale") { DCHECK_EQ(post_op_param.param.size(), 1); std::vector scales; @@ -376,7 +376,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { memory::format_tag::x)); context_.dst_scale_mem.reset( new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; @@ -395,7 +395,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { post_ops_attr.set_post_ops(post_ops); } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_pd.reset(new inner_product_forward::primitive_desc( *context_.fwd_desc, post_ops_attr, cpu_engine_)); #else @@ -405,7 +405,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { : prop_kind::forward_training, *context_.src_md, *context_.weight_md, *context_.bias_md, *context_.dst_md, post_ops_attr)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -428,7 +428,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { {DNNL_ARG_BIAS, *context_.bias_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); @@ -437,7 +437,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { net_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); } -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 context_.net_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.matmul_fwd); return; @@ -521,13 +521,13 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { DCHECK_EQ(post_op_param.param.size(), 1); key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.param[0]); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 } else if (post_op_param.name == "output_scale") { #else } else if (post_op_param.name == "src_scale" || post_op_param.name == "wei_scale" || post_op_param.name == "dst_scale") { -#endif // 
!ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 DCHECK_EQ(post_op_param.param.size(), 1); key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.param[0]); @@ -612,12 +612,12 @@ class MklDnnMatMulOpBase : public OpKernel { return; } -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 // For now, cache weights only for blocked format if (weight_md.get_format_kind() != memory::format_kind::blocked) { return; } -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // reorder and cache the weight weight.SetUsrMem(weight_md, &weight_tensor); @@ -638,7 +638,7 @@ class MklDnnMatMulOpBase : public OpKernel { // cache the memory descriptor auto expected_md = matmul_fwd_pd->weights_desc(); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 TensorShape weight_mkl_format; weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight)); @@ -653,7 +653,7 @@ class MklDnnMatMulOpBase : public OpKernel { expected_md.get_data_type(), expected_md.get_dims(), expected_md.get_inner_blks(), expected_md.get_inner_idxs(), expected_md.get_strides()); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } Tweight* GetCachedWeight(OpKernelContext* context, @@ -661,7 +661,7 @@ class MklDnnMatMulOpBase : public OpKernel { TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& weight_t = weight_oi_; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 const Tensor& weight_md_t = weight_oi_md_; // Check if the memory descriptor of the cached weight is same as @@ -693,7 +693,7 @@ class MklDnnMatMulOpBase : public OpKernel { const_cast(weight_t.flat().data())); } return nullptr; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } bool IsBiasCacheEmpty() TF_LOCKS_EXCLUDED(bias_cache_mutex_) { @@ -740,11 +740,11 @@ class MklDnnMatMulOpBase : public OpKernel { // Tensor to save reordered weight mutex mu_; Tensor weight_oi_ TF_GUARDED_BY(mu_); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 Tensor weight_oi_md_ TF_GUARDED_BY(mu_); #else FilterMemoryDesc weight_oi_md_ TF_GUARDED_BY(mu_); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 bool is_weight_const_; @@ -814,7 +814,7 @@ class MklMatMulPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.a_mem->set_data_handle( static_cast(const_cast(a_data)), *stream); context_.b_mem->set_data_handle( @@ -837,7 +837,7 @@ class MklMatMulPrimitive : public MklPrimitive { context_.sp_mem->set_data_handle(sp_data); if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data); if (add_data != nullptr) context_.add_mem->set_data_handle(add_data); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 execute_primitives(context_.matmul_primitives, stream, context_.net_args); // After execution, set data handle back @@ -865,9 +865,9 @@ class MklMatMulPrimitive : public MklPrimitive { std::shared_ptr sp_mem; // Descriptor and primitive-descriptor for MatMul. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr prim_desc; // Memory descriptors. 
@@ -888,9 +888,9 @@ class MklMatMulPrimitive : public MklPrimitive { mul_mem(nullptr), add_mem(nullptr), sp_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 prim_desc(nullptr), a_md(nullptr), b_md(nullptr), @@ -917,10 +917,10 @@ class MklMatMulPrimitive : public MklPrimitive { params.c_strides)); // Create matmul. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.desc.reset( new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Check if there is any fusion as post-ops auto const& post_op_params = params.post_op_params; @@ -929,14 +929,14 @@ class MklMatMulPrimitive : public MklPrimitive { if (!post_op_params.empty()) { for (auto const& post_op_param : post_op_params) { if (post_op_param.name == "output_scale") { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // TODO(intel-tf): Verify if this code is needed. If not, it needs to // be removed. DCHECK_EQ(post_op_param.param.size(), 1); std::vector scales; scales.push_back(post_op_param.param[0]); post_ops_attr.set_output_scales(0, scales); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } else if (post_op_param.name == "mul") { context_.mul_md.reset(new memory::desc({post_op_param.dims}, post_op_param.data_type, @@ -954,14 +954,14 @@ class MklMatMulPrimitive : public MklPrimitive { post_ops_attr.set_post_ops(post_ops); } post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.prim_desc.reset( new matmul::primitive_desc(*context_.desc, post_ops_attr, cpu_engine_)); #else context_.prim_desc.reset( new matmul::primitive_desc(cpu_engine_, *context_.a_md, *context_.b_md, *context_.c_md, post_ops_attr)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitive based on dummy data. 
context_.a_mem.reset( diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index 2360cdefba407c..304b499b62d3a0 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -116,12 +116,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(), this->data_format_tf_); memory::dims filter_dims, strides, padding_left, padding_right; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 &padding_left, &padding_right, is_pool2d); // Get a pooling op from the cached pool @@ -134,11 +134,11 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { else pooling_prop_kind = prop_kind::forward_training; MklPoolingParams fwdParams( -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 src_dims, output_dims_mkl_order, filter_dims, strides, #else src_dims, output_dims_mkl_order, filter_dims, strides, dilations, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 padding_left, padding_right, dnnl::algorithm::pooling_max, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, @@ -291,12 +291,12 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_shape); memory::dims filter_dims, strides, padding_left, padding_right; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 &padding_left, &padding_right, is_pool2d); memory::dims orig_input_dims_mkl_order = @@ -333,12 +333,12 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 strides, padding_left, padding_right, dnnl::algorithm::pooling_max, #else strides, dilations, padding_left, padding_right, dnnl::algorithm::pooling_max, -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); @@ -458,7 +458,6 @@ TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL3D_KERNELS); TF_CALL_float(REGISTER_MKL_MAXPOOL_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL_KERNELS); -#ifndef ENABLE_ONEDNN_V3 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool") .Device(DEVICE_CPU) .TypeConstraint("T") @@ -477,7 +476,6 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("_QuantizedMaxPool3D").Device(DEVICE_CPU).TypeConstraint("T"), MklMaxPoolingOp); -#endif // !ENABLE_ONEDNN_V3 } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index c73233cab8dd26..d53b8f37eef2e2 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -25,14 +25,14 @@ limitations under the License. 
#include "tensorflow/core/framework/kernel_shape_util.h" namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define AVG_POOLING_DCHECK(params) \ params.alg_kind == dnnl::algorithm::pooling_avg || #define GET_MEMORY_DESC(md) md.data #else #define AVG_POOLING_DCHECK(params) #define GET_MEMORY_DESC(md) md -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using dnnl::prop_kind; template @@ -58,7 +58,7 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { : memory::format_tag::any)); // Create a pooling descriptor. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_desc.reset(new pooling_forward::desc( fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, @@ -70,7 +70,7 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { cpu_engine_, fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, fwdParams.dilations, fwdParams.padding_left, fwdParams.padding_right)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 context_.dst_fmt = static_cast(memory::format_tag::any); // Create oneDNN internal memory object with dummy data. @@ -104,7 +104,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_stream); context_.dst_mem->set_data_handle(static_cast(dst_data), *fwd_stream); @@ -124,7 +124,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); // Set back data handle. @@ -141,10 +141,8 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, template class MklPoolingFwdPrimitive; template class MklPoolingFwdPrimitive; -#ifndef ENABLE_ONEDNN_V3 template class MklPoolingFwdPrimitive; template class MklPoolingFwdPrimitive; -#endif // !ENABLE_ONEDNN_V3 template void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { @@ -164,7 +162,7 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { ? bwdParams.src_format : memory::format_tag::any)); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 // Create a backward primitive. The implementation for backward must comply to // the workspace format it gets from forward pass, so we directly use src_md // and dst_md here. @@ -190,7 +188,7 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { cpu_engine_, bwdParams.alg_kind, *context_.src_md, *context_.dst_md, bwdParams.strides, bwdParams.filter_dims, bwdParams.dilations, bwdParams.padding_left, bwdParams.padding_right, *context_.fwd_pd)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create oneDNN internal memory object with dummy data. 
context_.diff_src_mem.reset(new memory(context_.bwd_pd.get()->diff_src_desc(), @@ -221,7 +219,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data)), *bwd_stream); context_.diff_src_mem->set_data_handle(static_cast(diff_src_data), @@ -238,7 +236,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast(ws_data)); } -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); @@ -327,13 +325,13 @@ void MklPoolParameters::Init(OpKernelContext* context, col_stride = GetTensorDim(stride, data_format, 'W'); depth_stride = GetTensorDim(stride, data_format, 'C'); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 // TODO(intel-tf): we are setting dilations to 0 to mimic the behavior of // oneDNN v2.x integration code. We can extend this in the future to support // dilations != 0 row_dilation = 0; col_dilation = 0; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // We only support 2D pooling across width/height and depthwise // pooling, not a combination. @@ -356,12 +354,12 @@ void MklPoolParameters::Init(OpKernelContext* context, col_stride = GetTensorDim(stride, data_format, '2'); depth_stride = GetTensorDim(stride, data_format, 'C'); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 // TODO(intel-tf): TensorFlow's 3D-pooling API does not support dilations planes_dilation = 0; row_dilation = 0; col_dilation = 0; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // We only support 3D pooling across depth/width/height and depthwise // pooling, not a combination. diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h index 012244a79691ad..d3ad93f73c264b 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -33,13 +33,13 @@ limitations under the License. 
namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define GET_DIMS data.dims #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define GET_DIMS get_dims() #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 using dnnl::pooling_backward; using dnnl::pooling_forward; @@ -54,9 +54,9 @@ struct MklPoolingParams { memory::dims dst_dims; memory::dims filter_dims; memory::dims strides; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::dims dilations; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 memory::dims padding_left; memory::dims padding_right; dnnl::algorithm alg_kind; @@ -67,9 +67,9 @@ struct MklPoolingParams { MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::dims dilations, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 memory::dims padding_left, memory::dims padding_right, dnnl::algorithm alg_kind, dnnl::prop_kind prop_kind, memory::format_tag src_format, memory::desc src_md, @@ -78,9 +78,9 @@ struct MklPoolingParams { dst_dims(dst_dims), filter_dims(filter_dims), strides(strides), -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 dilations(dilations), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 padding_left(padding_left), padding_right(padding_right), alg_kind(alg_kind), @@ -141,9 +141,9 @@ class MklPoolingFwdPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // Pooling forward descriptor and primitive descriptor. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr fwd_pd; // Memory descriptor. @@ -164,9 +164,9 @@ class MklPoolingFwdPrimitive : public MklPrimitive { ws_mem(nullptr), src_mem(nullptr), dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), @@ -220,9 +220,9 @@ class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(fwdParams.dst_dims); key_creator.AddAsKey(fwdParams.filter_dims); key_creator.AddAsKey(fwdParams.strides); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 key_creator.AddAsKey(fwdParams.dilations); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 key_creator.AddAsKey(fwdParams.padding_left); key_creator.AddAsKey(fwdParams.padding_right); key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); @@ -297,10 +297,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive { std::shared_ptr dst_md; // Forward and backward pooling descriptors and primitive descriptors. 
-#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; std::shared_ptr bwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 std::shared_ptr fwd_pd; std::shared_ptr bwd_pd; @@ -320,10 +320,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive { diff_dst_mem(nullptr), src_md(nullptr), dst_md(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), bwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 fwd_pd(nullptr), bwd_pd(nullptr), bwd(nullptr) { @@ -376,9 +376,9 @@ class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(bwdParams.dst_dims); key_creator.AddAsKey(bwdParams.filter_dims); key_creator.AddAsKey(bwdParams.strides); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 key_creator.AddAsKey(bwdParams.dilations); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 key_creator.AddAsKey(bwdParams.padding_left); key_creator.AddAsKey(bwdParams.padding_right); key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); @@ -416,11 +416,11 @@ struct MklPoolParameters { int col_stride; int depth_stride; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 int planes_dilation; // Pool3D int row_dilation; int col_dilation; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 int64 out_planes; // Pool3D int64 out_height; @@ -450,11 +450,11 @@ struct MklPoolParameters { row_stride(0), col_stride(0), depth_stride(0), -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 planes_dilation(0), row_dilation(0), col_dilation(0), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 out_planes(0), out_height(0), out_width(0), @@ -572,9 +572,9 @@ class MklPoolingOpBase : public OpKernel { void PoolParamsToDims(const MklPoolParameters* pool_params, memory::dims* filter_dims, memory::dims* strides, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::dims* dilations, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 memory::dims* padding_left, memory::dims* padding_right, bool is_pool2d) { if (is_pool2d) { @@ -583,10 +583,10 @@ class MklPoolingOpBase : public OpKernel { memory::dims({pool_params->window_rows, pool_params->window_cols}); *strides = memory::dims({pool_params->row_stride, pool_params->col_stride}); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 *dilations = memory::dims({pool_params->row_dilation, pool_params->col_dilation}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 *padding_left = memory::dims({static_cast(pool_params->pad_top), static_cast(pool_params->pad_left)}); *padding_right = memory::dims({static_cast(pool_params->pad_bottom), @@ -599,11 +599,11 @@ class MklPoolingOpBase : public OpKernel { *strides = memory::dims({pool_params->planes_stride, pool_params->row_stride, pool_params->col_stride}); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 *dilations = memory::dims({pool_params->planes_dilation, pool_params->row_dilation, pool_params->col_dilation}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 *padding_left = memory::dims({static_cast(pool_params->pad_P1), static_cast(pool_params->pad_top), diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 990b4f31f3d89a..8e0e85f1b5c4a9 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -96,7 +96,7 @@ limitations under the License. 
// // More information of this implementation can be found in // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training -#if defined(INTEL_MKL) +#ifdef INTEL_MKL #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/fill_functor.h" @@ -116,11 +116,11 @@ enum { namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define TSCALED_BIAS Tbias #else #define TSCALED_BIAS float -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 template @@ -320,7 +320,7 @@ class MklDnnQuantizedMatMulOp UserScratchPad scratch_pad; scratch_pad.AllocateSPTensor(matmul_fwd, context); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 Tbias* bias_data = this->GetBiasHandle( context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream); #else @@ -330,7 +330,7 @@ class MklDnnQuantizedMatMulOp Tensor temp_scaled_bias_tensor; this->GetBiasHandle(context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream, &temp_scaled_bias_tensor, &bias_data); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Execute inner-product matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data, matmul_fwd_dims, scratch_pad.Get(), cpu_stream); @@ -428,7 +428,7 @@ class MklDnnQuantizedMatMulOp absl::InvalidArgumentError(absl::StrCat( "`max_b` must be rank 0 but is rank ", max_weight_tensor.dims()))); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 const float min_input = min_input_tensor.scalar()(); const float max_input = max_input_tensor.scalar()(); const float min_weight = min_weight_tensor.scalar()(); @@ -441,7 +441,7 @@ class MklDnnQuantizedMatMulOp float wei_scale = std::max(std::abs(min_weight), std::abs(max_weight)) / 127.0; float dst_scale = 1.0; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // When the output type is quint8, the output data is requantized into // quint8. A post_op "output_scale" is added to do the conversion. if (std::is_same::value || @@ -464,7 +464,7 @@ class MklDnnQuantizedMatMulOp const float max_freezed_output = max_freezed_tensor.scalar()(); float scale_eightbit = std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 float min_output_value; float max_output_value; ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value); @@ -504,7 +504,7 @@ class MklDnnQuantizedMatMulOp params.post_op_params.push_back({"src_scale", {src_scale}}); params.post_op_params.push_back({"wei_scale", {wei_scale}}); params.post_op_params.push_back({"dst_scale", {dst_scale}}); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 } // This function handles bias conversion and compensation for MIN_FIRST and @@ -512,7 +512,7 @@ class MklDnnQuantizedMatMulOp // B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32 // If input is quantized via SCALE, // Bs32 = Qa * Qw * Bf32. 
-#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 Tbias* GetBiasHandle( OpKernelContext* context, std::shared_ptr& @@ -753,7 +753,7 @@ class MklDnnQuantizedMatMulOp } return false; } -#endif // ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 private: memory* input_bias_ = nullptr; @@ -849,4 +849,4 @@ REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBiasAndDequantize", } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc index 22b56e19e3bb63..95baa8b0f5eb3b 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc @@ -273,11 +273,11 @@ TEST_F(QuantizedMatMulTest, Small_withBiasAndReq) { // 178 * 1.00392 ~= 178.698 ~= 179 Tensor expected(allocator(), DT_QUINT8, TensorShape({2, 4})); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 test::FillValues(&expected, {84, 60, 116, 52, 183, 168, 233, 178}); #else test::FillValues(&expected, {84, 60, 116, 52, 184, 169, 234, 179}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 const Tensor& output = *GetOutput(0); test::ExpectTensorEqual(expected, output); @@ -470,11 +470,11 @@ TEST_F(QuantizedMatMulTest, Small_withBiasAndReluAndReq) { // 178 * 1.00392 ~= 178.698 ~= 179 Tensor expected(allocator(), DT_QUINT8, TensorShape({2, 4})); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 test::FillValues(&expected, {84, 60, 116, 52, 183, 168, 233, 178}); #else test::FillValues(&expected, {84, 60, 116, 52, 184, 169, 234, 179}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 const Tensor& output = *GetOutput(0); test::ExpectTensorEqual(expected, output); diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index 36e7178d5b7249..3e17aa2e162c56 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -57,11 +57,11 @@ enum { namespace tensorflow { -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 typedef Eigen::ThreadPoolDevice CPUDevice; @@ -69,9 +69,9 @@ struct MklReorderWithScaleFwdParams { memory::dims src_dims; memory::desc src_md; memory::desc dst_md; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 memory::desc scale_md; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 string dtypes = string(""); struct PostOpParam { string name; @@ -79,7 +79,7 @@ struct MklReorderWithScaleFwdParams { }; PostOpParam post_op_params; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md, memory::desc dst_md) : src_dims(src_dims), src_md(src_md), dst_md(dst_md) {} @@ -90,7 +90,7 @@ struct MklReorderWithScaleFwdParams { src_md(src_md), dst_md(dst_md), scale_md(scale_md) {} -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 }; class MklReorderWithScalePrimitive : public MklPrimitive { @@ -107,30 +107,30 @@ class MklReorderWithScalePrimitive : public MklPrimitive { std::shared_ptr GetPrimitive() { return context_.reorder_prim; } void Execute(void* src_data, void* dst_data, -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 void* scale_data, -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 std::shared_ptr reorder_stream) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock 
lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.src_mem->set_data_handle(src_data, *reorder_stream); context_.dst_mem->set_data_handle(dst_data, *reorder_stream); #else context_.src_mem->set_data_handle(src_data); context_.dst_mem->set_data_handle(dst_data); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 -#ifdef ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V2 context_.scale_mem->set_data_handle(scale_data); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 context_.reorder_prim->execute(*reorder_stream, context_.prim_args); // After execution, set data handle back. context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 context_.scale_mem->set_data_handle(DummyData); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } private: @@ -139,9 +139,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive { // MKL-DNN memory std::shared_ptr src_mem; std::shared_ptr dst_mem; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 std::shared_ptr scale_mem; -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // Reorder primitive descriptor and primitive std::shared_ptr reorder_pd; @@ -155,9 +155,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive { ReorderContext() : src_mem(nullptr), dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 scale_mem(nullptr), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 reorder_pd(nullptr), reorder_prim(nullptr) { } @@ -170,14 +170,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive { new memory(fwdParams.src_md, cpu_engine_, DummyData)); context_.dst_mem.reset( new memory(fwdParams.dst_md, cpu_engine_, DummyData)); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 context_.scale_mem.reset( new memory(fwdParams.scale_md, cpu_engine_, DummyData)); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 // Check if there is any fusion as post-ops dnnl::primitive_attr post_ops_attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 auto const& post_op_params = fwdParams.post_op_params; DCHECK(post_op_params.name == "scale"); DCHECK_EQ(post_op_params.param.size(), 1); @@ -186,7 +186,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { post_ops_attr.set_output_scales(0, scales); #else post_ops_attr.set_scales_mask(DNNL_ARG_SRC, 0 /* mask */); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 context_.reorder_pd.reset( new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_, @@ -196,10 +196,10 @@ class MklReorderWithScalePrimitive : public MklPrimitive { context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); context_.prim_args.insert({DNNL_ARG_FROM, *context_.src_mem}); context_.prim_args.insert({DNNL_ARG_TO, *context_.dst_mem}); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 context_.prim_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.scale_mem}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 } #ifdef DNNL_AARCH64_USE_ACL @@ -437,9 +437,9 @@ class MklQuantizeV2Op : public OpKernel { // they are wrapper MklDnnData src(&cpu_engine); MklDnnData dst(&cpu_engine); -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 MklDnnData scale(&cpu_engine); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 auto src_md = src_mkl_shape.IsMklTensor() @@ -538,7 +538,7 @@ class 
MklQuantizeV2Op : public OpKernel { const int64 number_of_steps = static_cast(1) << number_of_bits; scale_factor = (number_of_steps - 1.0) / (max_range - min_range); } -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 auto scale_md = memory::desc({1}, MklDnnType(), memory::format_tag::x); MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md, scale_md); @@ -549,7 +549,7 @@ class MklQuantizeV2Op : public OpKernel { #else MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md); fwdParams.dtypes.append(typeid(T).name()); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 fwdParams.post_op_params.name = "scale"; fwdParams.post_op_params.param.push_back(scale_factor); @@ -566,9 +566,9 @@ class MklQuantizeV2Op : public OpKernel { cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine())); reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(), -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 scale.GetUsrMemDataHandle(), -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 cpu_stream); output_min_tensor->scalar()() = min_range; diff --git a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_perchannel_test.cc b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_perchannel_test.cc index 5965bbda51ea8c..c3388fa51c31ce 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_perchannel_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops_perchannel_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) && defined(ENABLE_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define EIGEN_USE_THREADS #include @@ -191,4 +191,4 @@ TEST_F(QuantizedConv2DPerChannelTest, SmallOldAPI) { TestSmall(true); } TEST_F(QuantizedConv2DPerChannelTest, SmallNewAPI) { TestSmall(false); } } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 && ENABLE_MKL +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 03f19e21da86ca..2d7064b51692e9 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/nn_ops.cc. -#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V2) +// TODO(intel-tf): This file is no longer used and needs to be removed. +// This file will be an empty compilation unit when building with oneDNN v3.x +// (default behavior). It can be compiled only when building with oneDNN v2.x. #include @@ -1237,4 +1240,4 @@ TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 +#endif // INTEL_MKL && ENABLE_ONEDNN_V2 diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc index c507dd210b34d9..f39b96f3606f51 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) && defined(ENABLE_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V2) && defined(ENABLE_MKL) +// TODO(intel-tf): This file is no longer used and needs to be removed. +// This file will be an empty compilation unit when building with oneDNN v3.x +// (default behavior). It can be compiled only when building with oneDNN v2.x. #include "absl/strings/match.h" #include "tensorflow/cc/ops/const_op.h" @@ -136,4 +139,4 @@ TEST_ALL_SIZES(LeakyReluGrad) } // namespace tensorflow -#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 && ENABLE_MKL +#endif // INTEL_MKL && ENABLE_ONEDNN_V2 && ENABLE_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index ecf518baf4ed98..6080df45c1ae5a 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -106,7 +106,7 @@ class MklRequantizePerChannelOp : public OpKernel { } dnnl::primitive_attr reorder_attr; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 reorder_attr.set_output_scales(2, scales); #else reorder_attr.set_scales_mask(DNNL_ARG_SRC, 2); @@ -114,7 +114,7 @@ class MklRequantizePerChannelOp : public OpKernel { MklDnnType(), memory::format_tag::x}, cpu_engine_, scales.data()); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. @@ -157,9 +157,9 @@ class MklRequantizePerChannelOp : public OpKernel { reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_)); std::unordered_map reorder_args = { {DNNL_ARG_FROM, *input_mem_prim}, {DNNL_ARG_TO, *output_mem_prim}}; -#ifdef ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V2 reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}); -#endif // ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_V2 std::unique_ptr reorder_prim( new dnnl::reorder(reorder_pd)); reorder_prim->execute(*reorder_stream, reorder_args); diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 50caffa5e2d1e7..7a7e14c1d90524 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -67,7 +67,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_cpu_stream); context_.dst_mem->set_data_handle(static_cast(dst_data), @@ -76,7 +76,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size()); execute_primitives(context_.fwd_primitives, fwd_cpu_stream, @@ -98,9 +98,9 @@ class MklSoftmaxPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // Primitive descriptor. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 std::shared_ptr fwd_desc; -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Memory descriptor. 
std::shared_ptr src_md; @@ -115,9 +115,9 @@ class MklSoftmaxPrimitive : public MklPrimitive { SoftmaxFwdContext() : src_mem(nullptr), dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 fwd_desc(nullptr), -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 src_md(nullptr), fwd_pd(nullptr), softmax_fwd(nullptr) { @@ -132,7 +132,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { new memory::desc({fwdParams.src_dims}, MklDnnType(), src_format)); // Create softmax descriptor and primitive descriptor. -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 context_.fwd_desc.reset(new dnnl::softmax_forward::desc( prop_kind::forward_scoring, *context_.src_md, fwdParams.axis)); context_.fwd_pd.reset(new dnnl::softmax_forward::primitive_desc( @@ -142,7 +142,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { cpu_engine_, prop_kind::forward_inference, dnnl::algorithm::softmax_accurate, *context_.src_md, *context_.src_md /* dst_md */, fwdParams.axis)); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // Create memory primitive based on dummy data. context_.src_mem.reset( diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index caa0b11305d199..e426221ff844d5 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -159,7 +159,7 @@ inline void execute_primitives( } } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 #define ARE_MEMORY_DESCS_EQUAL(md1, md2) dnnl_memory_desc_equal(&md1, &md2) #define CREATE_MEMORY_DESC_USING_STRIDES dnnl_memory_desc_init_by_strides #define GET_DATA_TYPE data_type @@ -193,7 +193,7 @@ inline void execute_primitives( #define GET_STRIDES_DIMS(dims, dims_outer_blocks) dims #define INIT_DIMS_FROM_DESC(in_dims, md) in_dims = md.get_dims() #define MEMORY_DESC memory::desc -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 // In oneDNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor // (md) structure will no longer be recorded in its `format` field. 
Instead, it @@ -477,14 +477,14 @@ class MklDnnShape { inline void SetElemType(memory::data_type dt) { data_.T_ = dt; } inline const memory::data_type GetElemType() { return data_.T_; } -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 inline void SetMklLayout(memory::desc* md) { CHECK_NOTNULL(md); data_.mkl_md_ = md->data; } #else inline void SetMklLayout(const memory::desc& md) { data_.mkl_md_ = md; } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 inline const memory::desc GetMklLayout() const { return memory::desc(data_.mkl_md_); @@ -1342,7 +1342,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, std::vector net; net.push_back(dnnl::reorder(reorder_desc)); std::vector net_args; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); #else if (scale_mem != nullptr) { @@ -1352,7 +1352,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, } else { net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); } -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 ExecutePrimitive(net, &net_args, engine, ctx); } @@ -1514,11 +1514,11 @@ class MklDnnData { std::shared_ptr t_stream = nullptr) { CHECK_NOTNULL(user_memory_); CHECK_NOTNULL(data_buffer); -#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) user_memory_->set_data_handle(data_buffer, *t_stream); #else user_memory_->set_data_handle(data_buffer); -#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 } /// Set function for data buffer of user memory primitive. @@ -2195,7 +2195,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { auto to_inner_blks = to_desc.GET_INNER_BLKS; auto to_inner_idxs = to_desc.GET_INNER_IDXS; auto to_strides = to_desc.GET_STRIDES; -#ifndef ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V2 memory::dims from_inner_blks_1(from_inner_blks, &from_inner_blks[from_inner_nblks]); memory::dims from_inner_idxs_1(from_inner_idxs, @@ -2206,7 +2206,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { &from_strides[from_desc.ndims]); memory::dims to_strides_outer_blocks(to_strides, &to_strides[to_desc.ndims]); -#endif // !ENABLE_ONEDNN_V3 +#endif // ENABLE_ONEDNN_V2 key_creator.AddAsKey(prefix); #ifdef DNNL_AARCH64_USE_ACL diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc index e456a912ff2b63..30799c72dcb538 100644 --- a/tensorflow/core/util/mkl_util_test.cc +++ b/tensorflow/core/util/mkl_util_test.cc @@ -53,6 +53,9 @@ TEST(MklUtilTest, MklDnnTfShape) { EXPECT_NE(b_tf_shape_nchw, b_mkldnn_tf_shape); } +#ifdef ENABLE_ONEDNN_V2 +// TODO(intel-tf): This code is not tested for oneDNN v3.x and needs to be +// removed TEST(MklUtilTest, MklDnnBlockedFormatTest) { // Let's create 2D tensor of shape {3, 4} with 3 being innermost dimension // first (case 1) and then it being outermost dimension (case 2). 
@@ -80,6 +83,7 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) { EXPECT_EQ(b_md2.data.dims[0], 3); EXPECT_EQ(b_md2.data.dims[1], 4); } +#endif // ENABLE_ONEDNN_V2 TEST(MklUtilTest, LRUCacheTest) { // The cached objects are of type int* diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index aad649ecc49476..92c87d2b79fda0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -20,7 +20,6 @@ load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load("//third_party/mkl_dnn:build_defs.bzl", "if_onednn_v3") # TODO(mdan): Break into per-directory files. @@ -674,10 +673,10 @@ pywrap_tensorflow_macro( "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", - "@mkl_dnn_v1//:__subpackages__", "@nccl_archive//:__subpackages__", "@nsync//:__subpackages__", "@nvtx_archive//:__subpackages__", + "@onednn//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@png//:__subpackages__", @@ -692,7 +691,7 @@ pywrap_tensorflow_macro( "@upb//:__subpackages__", "@XNNPACK//:__subpackages__", "@zlib//:__subpackages__", - ] + tsl_async_value_deps() + if_onednn_v3(["@onednn_v3//:__subpackages__"]), + ] + tsl_async_value_deps(), win_def_file = ":pywrap_tensorflow_filtered_def_file", deps = [ "//tensorflow/c:c_api", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7ad3a2a358b6a1..875f35e4ce9e29 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -53,7 +53,6 @@ load( "if_mkldnn_aarch64_acl", "if_mkldnn_aarch64_acl_openmp", "if_mkldnn_openmp", - "if_onednn_v3", ) load( "//third_party/compute_library:build_defs.bzl", @@ -441,9 +440,8 @@ def tf_copts( # optimizations for Intel builds using oneDNN if configured if_enable_mkl(["-DENABLE_MKL"]) + if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) + - if_onednn_v3(["-DENABLE_ONEDNN_V3"]) + - if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) + - if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) + + if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1", "-DENABLE_ONEDNN_V2=1"]) + + if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP", "-DENABLE_ONEDNN_V2=1"]) + if_zendnn(["-DAMD_ZENDNN"]) + if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) + if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) + @@ -3259,9 +3257,9 @@ def tf_python_pybind_static_deps(testonly = False): "@local_config_tensorrt//:__subpackages__", "@local_execution_config_platform//:__subpackages__", "@mkl_dnn_acl_compatible//:__subpackages__", - "@mkl_dnn_v1//:__subpackages__", "@nsync//:__subpackages__", "@nccl_archive//:__subpackages__", + "@onednn//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@png//:__subpackages__", @@ -3280,7 +3278,6 @@ def tf_python_pybind_static_deps(testonly = False): "@com_google_benchmark//:__subpackages__", "@com_google_googletest//:__subpackages__", ] - static_deps += if_onednn_v3(["@onednn_v3//:__subpackages__"]) return if_oss(static_deps) # buildozer: enable=function-docstring-args diff --git a/tensorflow/tsl/BUILD b/tensorflow/tsl/BUILD index 47e987dcdab035..076836bc5418d7 100644 --- a/tensorflow/tsl/BUILD +++ b/tensorflow/tsl/BUILD @@ -266,24 +266,6 @@ config_setting( visibility = ["//visibility:public"], ) -selects.config_setting_group( - name = "linux_x86_64_with_onednn_v2", - match_all = [ - ":linux_x86_64", - "@org_tensorflow//third_party/mkl_dnn:build_with_onednn_v2", - ], - visibility = 
["//visibility:public"], -) - -selects.config_setting_group( - name = "linux_x86_64_with_onednn_v3", - match_all = [ - ":linux_x86_64", - "@org_tensorflow//third_party/mkl_dnn:build_with_onednn_v3", - ], - visibility = ["//visibility:public"], -) - config_setting( name = "ios_x86_64", flag_values = if_google( diff --git a/tensorflow/tsl/framework/contraction/BUILD b/tensorflow/tsl/framework/contraction/BUILD index 8e5a63dcddbda2..3cfd5f5351413d 100644 --- a/tensorflow/tsl/framework/contraction/BUILD +++ b/tensorflow/tsl/framework/contraction/BUILD @@ -121,7 +121,7 @@ cc_library( "//tensorflow/tsl:linux_ppc64le": [], "//tensorflow/tsl:linux_s390x": [], "//tensorflow/tsl:macos_arm64": [], - "//conditions:default": ["@mkl_dnn_v1//:mkl_dnn"], + "//conditions:default": ["@onednn//:mkl_dnn"], }), ) diff --git a/tensorflow/tsl/mkl/build_defs.bzl b/tensorflow/tsl/mkl/build_defs.bzl index eaa0b2dbde729a..28d15f2b4fcfca 100644 --- a/tensorflow/tsl/mkl/build_defs.bzl +++ b/tensorflow/tsl/mkl/build_defs.bzl @@ -102,9 +102,8 @@ def mkl_deps(): """ return select({ "@org_tensorflow//tensorflow/tsl/mkl:build_with_mkl_aarch64": ["@mkl_dnn_acl_compatible//:mkl_dnn_acl"], - "@org_tensorflow//tensorflow/tsl:linux_x86_64_with_onednn_v2": ["@mkl_dnn_v1//:mkl_dnn"], - "@org_tensorflow//tensorflow/tsl:linux_x86_64_with_onednn_v3": ["@onednn_v3//:mkl_dnn"], - "@org_tensorflow//tensorflow/tsl:windows": ["@mkl_dnn_v1//:mkl_dnn"], + "@org_tensorflow//tensorflow/tsl:linux_x86_64": ["@onednn//:mkl_dnn"], + "@org_tensorflow//tensorflow/tsl:windows": ["@onednn//:mkl_dnn"], "//conditions:default": [], }) diff --git a/tensorflow/tsl/tsl.bzl b/tensorflow/tsl/tsl.bzl index a17264c50bfe26..ce729c50664e5b 100644 --- a/tensorflow/tsl/tsl.bzl +++ b/tensorflow/tsl/tsl.bzl @@ -253,8 +253,8 @@ def tsl_copts( # optimizations for Intel builds using oneDNN if configured if_enable_mkl(["-DENABLE_MKL"]) + if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) + - if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) + - if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) + + if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1", "-DENABLE_ONEDNN_V2=1"]) + + if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP", "-DENABLE_ONEDNN_V2=1"]) + if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) + if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) + if_linux_x86_64(["-msse3"]) + diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 9f1f474c270fe9..abd115d9fee4c8 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -192,11 +192,11 @@ def _tf_repositories(): ) tf_http_archive( - name = "onednn_v3", + name = "onednn", build_file = "//third_party/mkl_dnn:mkldnn_v1.BUILD", - sha256 = "28e31f2d576e1a7e3a796f5c33c1d733c256078cff1c48b9e2a692d5975e1401", - strip_prefix = "oneDNN-3.1", - urls = tf_mirror_urls("https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.1.tar.gz"), + sha256 = "8b1db9cc5799ae39c2a567eb836962de0346d79fbc3d8e6f7090a3d9f8729129", + strip_prefix = "oneDNN-3.2", + urls = tf_mirror_urls("https://github.com/oneapi-src/oneDNN/archive/refs/tags/v3.2.tar.gz"), ) tf_http_archive( diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD index 151bc26b01cd58..7fc25ea49d5987 100644 --- a/third_party/mkl_dnn/BUILD +++ b/third_party/mkl_dnn/BUILD @@ -28,27 +28,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "build_with_onednn_v3", - define_values = { - "build_with_mkl": "true", - "build_with_onednn_v3": "true", - }, - visibility = ["//visibility:public"], 
-) - -# The following config is needed since oneDNN v2.x and v3.x are API incompatible. -config_setting( - name = "build_with_onednn_v2", - define_values = { - # We are not defining 'build_with_mkl' since this config can be invoked - # on x86_64 platforms without --config=mkl (through Eigen contraction - # kernel) - "build_with_onednn_v2": "true", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "build_with_mkl_aarch64_openmp", define_values = { diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl index 0e24769059068f..c8ac81b673c91e 100644 --- a/third_party/mkl_dnn/build_defs.bzl +++ b/third_party/mkl_dnn/build_defs.bzl @@ -1,7 +1,6 @@ """Starlark macros for oneDNN. if_mkldnn_openmp checks if we are building x86 backend with OpenMP. -if_onednn_v3 checks if we are using oneDNN v3. if_mkldnn_aarch64_acl checks if we are building with Arm Compute Library. if_mkldnn_aarch64_acl_openmp checks if we are building ACL with OpenMP. """ @@ -22,19 +21,6 @@ def if_mkldnn_openmp(if_true, if_false = []): "//conditions:default": if_false, }) -def if_onednn_v3(if_true, if_false = []): - """Returns `if_true` if oneDNN v3.x is used. - - Returns a select statement which evaluates to if_true if we're building - with oneDNN v3.x open source library only. Otherwise, the select statement - evaluates to if_false. - - """ - return select({ - "@org_tensorflow//third_party/mkl_dnn:build_with_onednn_v3": if_true, - "//conditions:default": if_false, - }) - def if_mkldnn_aarch64_acl(if_true, if_false = []): return select({ "@org_tensorflow//third_party/mkl:build_with_mkl_aarch64": if_true, diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index dd3f63f738c8e3..e648e76cb3e0d9 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -15,6 +15,7 @@ _CMAKE_COMMON_LIST = { "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", + "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", @@ -86,9 +87,9 @@ expand_template( name = "dnnl_version_h", out = "include/oneapi/dnnl/dnnl_version.h", substitutions = { - "@DNNL_VERSION_MAJOR@": "2", - "@DNNL_VERSION_MINOR@": "7", - "@DNNL_VERSION_PATCH@": "3", + "@DNNL_VERSION_MAJOR@": "3", + "@DNNL_VERSION_MINOR@": "2", + "@DNNL_VERSION_PATCH@": "0", "@DNNL_VERSION_HASH@": "N/A", }, template = "include/oneapi/dnnl/dnnl_version.h.in", From dfa72976f7fb85b3061c2b7b089f2c5d2cd5075a Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 24 Jul 2023 16:56:40 -0700 Subject: [PATCH 081/410] Addressed review comments --- .../core/kernels/mkl/mkl_avgpooling_op.cc | 16 +- tensorflow/core/kernels/mkl/mkl_concat_op.cc | 12 +- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 28 ++-- .../kernels/mkl/mkl_conv_grad_input_ops.cc | 32 ++-- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 120 +++++++-------- tensorflow/core/kernels/mkl/mkl_conv_ops.h | 4 +- .../core/kernels/mkl/mkl_dequantize_op.cc | 8 +- .../mkl/mkl_eltwise_activation_base_op.h | 36 ++--- .../kernels/mkl/mkl_fused_batch_norm_op.cc | 144 +++++++++--------- .../kernels/mkl/mkl_fused_instance_norm_op.cc | 24 +-- tensorflow/core/kernels/mkl/mkl_kernel_util.h | 4 +- 
.../core/kernels/mkl/mkl_layer_norm_op.cc | 12 +- .../core/kernels/mkl/mkl_matmul_ops_common.h | 76 ++++----- .../core/kernels/mkl/mkl_maxpooling_op.cc | 16 +- .../kernels/mkl/mkl_pooling_ops_common.cc | 28 ++-- .../core/kernels/mkl/mkl_pooling_ops_common.h | 60 ++++---- tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 22 +-- .../core/kernels/mkl/mkl_qmatmul_op_test.cc | 8 +- .../core/kernels/mkl/mkl_quantize_op.cc | 60 ++++---- tensorflow/core/kernels/mkl/mkl_relu_op.cc | 4 +- .../core/kernels/mkl/mkl_relu_op_test.cc | 4 +- .../mkl/mkl_requantize_per_channel_op.cc | 8 +- tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 16 +- tensorflow/core/util/mkl_util.h | 20 +-- tensorflow/core/util/mkl_util_test.cc | 4 +- tensorflow/tensorflow.bzl | 5 + tensorflow/tsl/mkl/build_defs.bzl | 15 ++ 27 files changed, 403 insertions(+), 383 deletions(-) diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index 169d2878e460f7..7e5ebd5b6a8362 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -83,12 +83,12 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { memory::dims filter_dims, strides, padding_left, padding_right; // Get src/filter/stride/padding information. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 &padding_left, &padding_right, is_pool2d); // Get the input memory descriptor. @@ -114,11 +114,11 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind = prop_kind::forward_training; MklPoolingParams fwdParams( -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 src_dims, output_dims_mkl_order, filter_dims, strides, #else src_dims, output_dims_mkl_order, filter_dims, strides, dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 padding_left, padding_right, dnnl::algorithm::pooling_avg_exclude_padding, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, @@ -285,12 +285,12 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { output_shape); memory::dims filter_dims, strides, padding_left, padding_right; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 &padding_left, &padding_right, is_pool2d); memory::dims orig_input_dims_mkl_order = @@ -335,11 +335,11 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { // that is used in the backward pass. 
MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 strides, padding_left, padding_right, #else strides, dilations, padding_left, padding_right, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 dnnl::algorithm::pooling_avg_exclude_padding, prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index f41b8f018827b7..497f997860176b 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -40,7 +40,7 @@ using dnnl::concat; using dnnl::stream; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define CONCAT_PRIM_DESC(eng, concat_dims, src_md, dst_md_ptr) \ concat::primitive_desc(*dst_md_ptr, concat_dims, src_md, eng) #define CONCAT_PRIM_DESC_USING_SRC(eng, concat_dims, src_md) \ @@ -54,7 +54,7 @@ namespace tensorflow { concat::primitive_desc(eng, concat_dims, src_md) #define GET_MEMORY_DESC(md) md #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 typedef Eigen::ThreadPoolDevice CPUDevice; // List of TensorShape objects. Used in Concat/Split layers. @@ -302,7 +302,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { #endif DCHECK_EQ(in_data.size(), context_.data_mem.size()); for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.data_mem_shdptr[i]->set_data_handle( static_cast(in_data[i].get_data_handle()), *fwd_stream); } @@ -314,7 +314,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { } context_.dst_mem->set_data_handle( static_cast(dst_data.get_data_handle())); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { context_.data_mem[i] = *context_.data_mem_shdptr[i]; @@ -663,7 +663,7 @@ class MklConcatOp : public OpKernel { auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat( mkl_input_shapes[k].GetTfDataFormat()); if (src_tf_fmt != mkl_common_format) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); #else @@ -673,7 +673,7 @@ class MklConcatOp : public OpKernel { else if (src_md.get_ndims() == 4) src_dims = {src_md.get_dims()[0], src_md.get_dims()[1], src_md.get_dims()[2], src_md.get_dims()[3]}; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 src_md = memory::desc(src_dims, MklDnnType(), mkl_common_format); } diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index dca67b791cf796..7c24a23d68226e 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -36,9 +36,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 using ConvBwdFilterDesc = dnnl::convolution_backward_weights::desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc; struct MklConvBwdFilterParams { @@ -96,7 +96,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock 
lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *bwd_filter_stream); @@ -121,7 +121,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { } context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream, context_.bwd_filter_primitives_args); @@ -159,15 +159,15 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { // Primitive descriptor and descriptor for convolution backward filter. std::shared_ptr bwd_filter_pd; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr bwd_filter_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Primitive descriptor and descriptor for convolution forward. std::shared_ptr fwd_pd; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Convolution backward filter primitive. std::shared_ptr conv_bwd_filter; @@ -188,13 +188,13 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { diff_filter_mem(nullptr), diff_bias_mem(nullptr), diff_dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 bwd_filter_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 src_md(nullptr), diff_filter_md(nullptr), diff_bias_md(nullptr), @@ -228,7 +228,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType(), memory::format_tag::x)); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // Create descriptor and primitive descriptor for convolution forward. context_.fwd_desc.reset(new ConvFwdDesc( prop_kind::forward, dnnl::algorithm::convolution_direct, @@ -276,7 +276,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, *context_.fwd_pd)); } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto bwd_filter_pd = context_.bwd_filter_pd.get(); diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index d6f71efd2f4c56..16a6db176843b1 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -39,15 +39,15 @@ using dnnl::prop_kind; using dnnl::stream; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 using ConvBwdDataDesc = dnnl::convolution_backward_data::desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using ConvBwdDataPd = dnnl::convolution_backward_data::primitive_desc; // Utility classes for enabling primitive reuse for conv bwd input. 
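For context on the hunk above, which confines ConvBwdDataDesc to the oneDNN v2.x build: oneDNN v3.x drops op descriptors, so a backward-data primitive descriptor is built directly from memory descriptors plus the forward primitive descriptor as a hint. The following is a minimal standalone sketch of that v3-style construction, not part of the patch; the shapes, the plain nchw/oihw formats, and the zero dilation are illustrative assumptions rather than values taken from the TensorFlow kernel.

// Minimal oneDNN v3-style sketch: convolution backward-data primitive
// descriptor created directly (no convolution_backward_data::desc),
// with the forward pd used only as a hint.
#include <unordered_map>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::dims src_dims = {1, 3, 8, 8};  // NCHW (illustrative)
  memory::dims wei_dims = {6, 3, 3, 3};  // OIHW (illustrative)
  memory::dims dst_dims = {1, 6, 8, 8};
  memory::dims strides = {1, 1}, dilates = {0, 0};  // oneDNN dilation = framework dilation - 1
  memory::dims pad_l = {1, 1}, pad_r = {1, 1};

  auto src_md = memory::desc(src_dims, memory::data_type::f32, memory::format_tag::nchw);
  auto wei_md = memory::desc(wei_dims, memory::data_type::f32, memory::format_tag::oihw);
  auto dst_md = memory::desc(dst_dims, memory::data_type::f32, memory::format_tag::nchw);

  // Forward pd, created directly from parameters in v3.x.
  auto fwd_pd = convolution_forward::primitive_desc(
      eng, prop_kind::forward_training, algorithm::convolution_direct,
      src_md, wei_md, dst_md, strides, dilates, pad_l, pad_r);

  // Backward-data pd: same parameter list plus the forward pd as a hint.
  auto bwd_pd = convolution_backward_data::primitive_desc(
      eng, algorithm::convolution_direct, src_md, wei_md, dst_md,
      strides, dilates, pad_l, pad_r, fwd_pd);

  // Buffers are left uninitialized; a real kernel fills diff_dst and weights.
  memory diff_src_mem(bwd_pd.diff_src_desc(), eng);
  memory wei_mem(bwd_pd.weights_desc(), eng);
  memory diff_dst_mem(bwd_pd.diff_dst_desc(), eng);

  convolution_backward_data(bwd_pd).execute(
      strm, {{DNNL_ARG_DIFF_DST, diff_dst_mem},
             {DNNL_ARG_WEIGHTS, wei_mem},
             {DNNL_ARG_DIFF_SRC, diff_src_mem}});
  strm.wait();
  return 0;
}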
@@ -103,7 +103,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.diff_src_mem->set_data_handle( static_cast(const_cast(diff_src_data)), *bwd_input_stream); @@ -118,7 +118,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { static_cast(const_cast(filter_data))); context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 execute_primitives(context_.bwd_input_primitives, bwd_input_stream, context_.bwd_input_primitives_args); @@ -143,15 +143,15 @@ class MklConvBwdInputPrimitive : public MklPrimitive { // Conv backward input primitive descriptor and descriptor. std::shared_ptr bwd_input_pd; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr bwd_input_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Primitive descriptor and descriptor for conv fwd std::shared_ptr fwd_pd; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Conv bwd input primitive. std::shared_ptr conv_bwd_input; @@ -170,13 +170,13 @@ class MklConvBwdInputPrimitive : public MklPrimitive { filter_mem(nullptr), diff_dst_mem(nullptr), bwd_input_pd(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 bwd_input_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 conv_bwd_input(nullptr), diff_src_md(nullptr), filter_md(nullptr), @@ -204,7 +204,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { memory::format_tag::any)); // Create descriptors for both conv fwd and conv bwd input. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.bwd_input_desc.reset(new ConvBwdDataDesc( dnnl::algorithm::convolution_direct, *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, @@ -233,7 +233,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, convBwdInputDims.dilations, convBwdInputDims.padding_left, convBwdInputDims.padding_right, *context_.fwd_pd)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory using dummy data. 
context_.diff_src_mem.reset(new memory( diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index e37c6c2698cd33..6709876b8c80d8 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -39,7 +39,7 @@ using ReorderPd = dnnl::reorder::primitive_desc; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define APPEND_DEPTHWISE(wei_dt, bias_dt, dst_dt, kernel, stride, padding, \ scales_mask, scales) \ append_dw(wei_dt, bias_dt, dst_dt, kernel, stride, padding, scales_mask, \ @@ -74,13 +74,13 @@ namespace tensorflow { #define SCALE wei_scale #define SUMMAND_SCALE_U8(summand_range, output_range) summand_range / 255.0f #define SUMMAND_SCALE_S8(summand_range, output_range) summand_range / 127.0f -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) #define FWD_STREAM , *fwd_stream #else #define FWD_STREAM -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 // TODO(intel-tf) Remove this once old API of quantized ops is abandoned namespace quantized_fusions { @@ -284,9 +284,9 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr dst_scale_mem; // Desc & primitive desc -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr fwd_pd; // Memory desc @@ -325,9 +325,9 @@ class MklConvFwdPrimitive : public MklPrimitive { src_scale_mem(nullptr), wei_scale_mem(nullptr), dst_scale_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), src_md(nullptr), filter_md(nullptr), @@ -372,7 +372,7 @@ class MklConvFwdPrimitive : public MklPrimitive { MklDnnType(), memory::format_tag::any)); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // Create a convolution descriptor context_.fwd_desc.reset(new convolution_forward::desc( prop_kind::forward, dnnl::algorithm::convolution_direct, @@ -385,7 +385,7 @@ class MklConvFwdPrimitive : public MklPrimitive { *context_.src_md, *context_.filter_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } if (!convFwdDims.fuse_bn_dims.empty()) { @@ -422,7 +422,7 @@ class MklConvFwdPrimitive : public MklPrimitive { } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 post_ops.append_sum(op_scale); #else if (post_op_param.dtype != DT_INVALID) { @@ -436,8 +436,8 @@ class MklConvFwdPrimitive : public MklPrimitive { } else { post_ops.append_sum(op_scale); } -#endif //! !ENABLE_ONEDNN_V2 -#ifdef ENABLE_ONEDNN_V2 +#endif //! 
ENABLE_ONEDNN_V3 +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { if (post_op_param.param.size() == 1) { post_ops_attr.set_output_scales(0, post_op_param.param); @@ -470,7 +470,7 @@ class MklConvFwdPrimitive : public MklPrimitive { memory::format_tag::x)); context_.dst_scale_mem.reset( new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "fuse_bn") { post_ops.append_binary(dnnl::algorithm::binary_sub, *context_.bn_mean_md); @@ -488,7 +488,7 @@ class MklConvFwdPrimitive : public MklPrimitive { } post_ops_attr.set_post_ops(post_ops); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_pd.reset( new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_)); #else @@ -505,7 +505,7 @@ class MklConvFwdPrimitive : public MklPrimitive { convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, post_ops_attr)); } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -530,7 +530,7 @@ class MklConvFwdPrimitive : public MklPrimitive { {DNNL_ARG_BIAS, *context_.bias_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, @@ -538,7 +538,7 @@ class MklConvFwdPrimitive : public MklPrimitive { { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem }}); } -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } else if (!convFwdDims.fuse_bn_dims.empty()) { context_.bn_scale_mem.reset( new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); @@ -566,7 +566,7 @@ class MklConvFwdPrimitive : public MklPrimitive { {DNNL_ARG_WEIGHTS, *context_.filter_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}, @@ -574,7 +574,7 @@ class MklConvFwdPrimitive : public MklPrimitive { { DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem }}); } -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } context_.fwd_primitives_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.conv_fwd); @@ -663,13 +663,13 @@ class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory { for (auto& param : post_op_param.param) { key_creator.AddAsKey(param); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { #else } else if (post_op_param.name == "src_scale" || post_op_param.name == "wei_scale" || post_op_param.name == "dst_scale") { -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 key_creator.AddAsKey(post_op_param.partial_key); } else if (post_op_param.name == "fuse_bn") { key_creator.AddAsKey(post_op_param.name); @@ -1261,11 +1261,11 @@ class MklConvOp : public OpKernel { MklDnnShape* output_mkl_shape, Tensor** output_tensor) { DCHECK(output_tensor); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto dst_md = conv_prim_desc.dst_desc(); if (!std::is_same::value) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 dst_md.data.data_type = static_cast(MklDnnType()); #else @@ -1274,7 +1274,7 @@ class MklConvOp : public OpKernel 
{ dst_md = memory::desc(output_dims_mkl_order, MklDnnType(), MklTensorFormatToMklDnnDataFormat(output_tf_format)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } #else auto dst_md = @@ -1283,7 +1283,7 @@ class MklConvOp : public OpKernel { : memory::desc(conv_prim_desc.dst_desc().get_dims(), MklDnnType(), MklTensorFormatToMklDnnDataFormat(output_tf_format)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Allocate shape of MKL tensor output_mkl_shape->SetMklTensor(true); @@ -1372,11 +1372,11 @@ class MklConvOp : public OpKernel { string data_format_str_; TensorFormat data_format_; Tensor cached_filter_data_ TF_GUARDED_BY(mu_); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 Tensor cached_filter_md_ TF_GUARDED_BY(mu_); #else FilterMemoryDesc cached_filter_md_ TF_GUARDED_BY(mu_); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Initialize to values the template is instantiated with bool fuse_biasadd_ = bias_enabled; @@ -1426,7 +1426,7 @@ class MklConvOp : public OpKernel { *filter_tensor = &cached_filter_data_; memory::desc weights_desc = conv_prim_desc.weights_desc(); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // There is no tensor format in DNNL 1.x. So we cache the complete filter // descriptor as flat byte array. TensorShape cached_filter_md_shape; @@ -1445,7 +1445,7 @@ class MklConvOp : public OpKernel { weights_desc.get_data_type(), weights_desc.get_dims(), weights_desc.get_inner_blks(), weights_desc.get_inner_idxs(), weights_desc.get_strides()); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, @@ -1506,12 +1506,12 @@ class MklConvOp : public OpKernel { return; } -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 // For now, cache filter only for blocked format if (filter_md.get_format_kind() != memory::format_kind::blocked) { return; } -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // Otherwise, cache reordered filter filter.SetUsrMem(filter_md, &filter_tensor); @@ -1527,7 +1527,7 @@ class MklConvOp : public OpKernel { memcpy(cached_filter_data, filter_data, cached_filter_data_size); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // TODO(intel-tf): This function is no longer used and needs to be removed bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, const Tensor& cached_filter_md) { @@ -1545,14 +1545,14 @@ class MklConvOp : public OpKernel { } return true; } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 Tfilter* GetCachedFilter(OpKernelContext* context, const memory::desc& filter_md) TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& cached_filter_data = cached_filter_data_; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 const Tensor& cached_filter_md = cached_filter_md_; // Check if the memory descriptor of the cached weights is the same as @@ -1580,7 +1580,7 @@ class MklConvOp : public OpKernel { const_cast(cached_filter_data.flat().data())); } return nullptr; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } }; @@ -1971,19 +1971,19 @@ class MklQuantizedConvOp // If Requantize is fused, we set output_scale as first post op since it is // logically applied before any post op. Then we maintain the order of post // ops according to the order of fused_ops. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 int idx = fuse_requantize ? 
1 : 0; #else post_op_to_idx_["src_scale"] = 0; post_op_to_idx_["wei_scale"] = 1; post_op_to_idx_["dst_scale"] = 2; int idx = 3; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 for (int i = 0; i < fused_ops_.size(); ++i) { if (fused_ops_[i] == "Requantize") { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 post_op_to_idx_["output_scale"] = 0; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else if (fused_ops_[i] == "Sum") { post_op_to_idx_["sum"] = idx++; } else if (fused_ops_[i] == "Relu") { @@ -2152,11 +2152,11 @@ class MklQuantizedConvOp std::vector SCALE(depth); float float_input_range = std::max(std::abs(min_input), std::abs(max_input)); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 float int_input_limit = std::is_same::value ? 255.0f : 127.0f; const float src_scale = float_input_range / int_input_limit; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 if (std::is_same::value || std::is_same::value) { // min_freezed_output and max_freezed_output are the actual range @@ -2178,18 +2178,18 @@ class MklQuantizedConvOp float float_filter_range = std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); // To understand the scaling, please see mkl_requantize_ops_test. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 scales[i] = int_output_limit * float_input_range * float_filter_range / (int_const_scale_limit * float_output_range); #else wei_scale[i] = float_filter_range / 127.0; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } // we are creating a partial key here to use with primitive key caching to // improve key creation performance. Instead of using actual values we are // using the pointers for min/max_filter_vector, and this works since the // filter vector here is a constant. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 FactoryKeyCreator param_key; param_key.AddAsKey(min_input); param_key.AddAsKey(max_input); @@ -2209,9 +2209,9 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {dst_scale}, dst_param_key.GetKey()}; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else { -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 if (!std::is_same::value) TF_CHECK_OK(absl::FailedPreconditionError( "Output datatype is expected to be qint32.")); @@ -2234,10 +2234,10 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {dst_scale}, dst_param_key.GetKey()}; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 FactoryKeyCreator src_param_key; src_param_key.AddAsKey(min_input); src_param_key.AddAsKey(max_input); @@ -2251,7 +2251,7 @@ class MklQuantizedConvOp src_param_key.GetKey()}; params.post_op_params[post_op_to_idx_["wei_scale"]] = { "wei_scale", dnnl::algorithm::undef, wei_scale, wei_param_key.GetKey()}; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 if (this->get_fuse_add()) { // Calculate the scale (beta in oneDNN api term) for sum DataType summand_dt = this->input_type(this->get_input_add_idx()); @@ -2326,9 +2326,9 @@ class MklQuantizedConvOp dnnl::algorithm::undef, {1.0}, "", -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 summand_dt -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 }; } } @@ -2376,7 +2376,7 @@ class MklQuantizedConvOp "Summand cannot be forwarded in the current fusion.")); return; } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 MklConvOp< Device, Tinput, /*Tfilter*/ qint8, Tbias, Toutput, Ttemp_output, /*Tpadding*/ int32, @@ -2411,7 +2411,7 @@ class MklQuantizedConvOp std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); 
} dnnl::primitive_attr reorder_attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 if (depth == 1) { reorder_attr.set_output_scales(0, scales); } else { @@ -2424,7 +2424,7 @@ class MklQuantizedConvOp reorder_attr.set_scales_mask(DNNL_ARG_SRC, 0); reorder_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); reorder_attr.set_scales_mask(DNNL_ARG_DST, 0); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto summand_md = memory::desc(output_dims_mkl_order, MklDnnType(), memory::format_tag::nhwc); void* summand_buf = @@ -2456,7 +2456,7 @@ class MklQuantizedConvOp absl::InvalidArgumentError( "Summand cannot be forwarded in the current fusion.")); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } } @@ -2466,7 +2466,7 @@ class MklQuantizedConvOp if (!this->get_fuse_biasadd()) { return nullptr; } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 if (std::is_same::value) { return static_cast( const_cast(bias_tensor.flat().data())); @@ -2502,7 +2502,7 @@ class MklQuantizedConvOp } if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { dnnl::primitive_attr bias_attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 if (depth == 1) { bias_attr.set_output_scales(0, scales_); } else { @@ -2515,7 +2515,7 @@ class MklQuantizedConvOp bias_attr.set_scales_mask(DNNL_ARG_SRC, 0); bias_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); bias_attr.set_scales_mask(DNNL_ARG_DST, 0); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto bias_md = memory::desc({static_cast(bias_tensor.NumElements())}, MklDnnType(), memory::format_tag::x); @@ -2641,7 +2641,7 @@ class MklQuantizedConvOp } return GetCachedBias(context); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } bool is_bias_const_; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.h b/tensorflow/core/kernels/mkl/mkl_conv_ops.h index 6abb2862d35e6d..0384df4b309285 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.h @@ -48,11 +48,11 @@ using dnnl::stream; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // Op descriptor is no longer supported in oneDNN v3.x. Instead, primitive // descriptor will directly accept primitive parameters during creation. using ConvFwdDesc = dnnl::convolution_forward::desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using ConvFwdPd = dnnl::convolution_forward::primitive_desc; class MklDnnConvUtil { diff --git a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc index 957612606c32d4..a41afd657824a4 100644 --- a/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_dequantize_op.cc @@ -152,7 +152,7 @@ class MklDequantizeOp : public OpKernel { std::vector scales; scales.push_back(scale_factor); primitive_attr attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 attr.set_output_scales(0, scales); #else attr.set_scales_mask(DNNL_ARG_SRC, 0); @@ -160,7 +160,7 @@ class MklDequantizeOp : public OpKernel { MklDnnType(), memory::format_tag::x}, cpu_engine, scales.data()); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::vector net; // Create reorder primitive and then execute. 
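The dequantize hunk above replaces attr.set_output_scales(0, scales) with set_scales_mask plus a scale memory bound at execution time, which is how reorder scaling works under oneDNN v3.x. Below is a minimal self-contained sketch of that pattern, not part of the patch; the 4-element s8 to f32 reorder and the 0.5f scale are illustrative assumptions only.

// oneDNN v3 sketch: scaled reorder (e.g. dequantize s8 -> f32). The scale
// is no longer baked into the attribute; it is supplied at execution time
// via DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC.
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::dims dims = {4};
  auto src_md = memory::desc(dims, memory::data_type::s8, memory::format_tag::x);
  auto dst_md = memory::desc(dims, memory::data_type::f32, memory::format_tag::x);

  std::vector<int8_t> src_vals = {2, 4, 6, 8};
  memory src_mem(src_md, eng, src_vals.data());
  memory dst_mem(dst_md, eng);

  // Mask 0 = one common scale for the whole source tensor.
  primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_SRC, 0);

  // The scale value lives in a small f32 memory passed as a runtime argument.
  float scale = 0.5f;  // illustrative value
  memory scale_mem({{1}, memory::data_type::f32, memory::format_tag::x}, eng, &scale);

  auto reorder_pd = reorder::primitive_desc(eng, src_md, eng, dst_md, attr);
  reorder(reorder_pd).execute(
      strm, {{DNNL_ARG_FROM, src_mem},
             {DNNL_ARG_TO, dst_mem},
             {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}});
  strm.wait();  // dst now holds {1, 2, 3, 4}
  return 0;
}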
@@ -169,7 +169,7 @@ class MklDequantizeOp : public OpKernel { dst.GetUsrMem()->get_desc(), attr); net.push_back(reorder(reorder_pd)); std::vector> reorder_net_args; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 reorder_net_args.push_back({{DNNL_ARG_FROM, *src.GetUsrMem()}, { DNNL_ARG_TO, *dst.GetUsrMem() }}); @@ -178,7 +178,7 @@ class MklDequantizeOp : public OpKernel { {{DNNL_ARG_FROM, *src.GetUsrMem()}, {DNNL_ARG_TO, *dst.GetUsrMem()}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}}); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 execute_primitives(net, reorder_stream, reorder_net_args); } catch (dnnl::error& e) { string error_msg = "Status: " + std::to_string(e.status) + diff --git a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h index ad1f251b9f4022..9bfddc2f1d0991 100644 --- a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h +++ b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h @@ -43,11 +43,11 @@ using dnnl::stream; using EltwiseFwdActivationPd = dnnl::eltwise_forward::primitive_desc; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define GET_MEMORY_DESC(md) md.data #else #define GET_MEMORY_DESC(md) md -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // TODO(tf-onednn): Consolidate this class with `MklEltWiseFwdParams` // in `mkl_relu_op.cc`. @@ -59,23 +59,23 @@ class MklEltwiseFwdActivationParams { public: memory::dims src_dims; memory::desc src_md; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 algorithm alg_kind; float alpha; float beta; MklEltwiseFwdActivationParams(memory::dims src_dims, memory::desc src_md, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 algorithm alg_kind, float alpha, float beta) : src_dims(src_dims), src_md(src_md), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 dst_md(dst_md), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 alg_kind(alg_kind), alpha(alpha), beta(beta) { @@ -134,9 +134,9 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // desc & primitive desc -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr fwd_pd; // memory desc @@ -156,9 +156,9 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { EltwiseFwdActivationContext() : src_mem(nullptr), dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), @@ -174,7 +174,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { context_.src_mpd.reset(new memory::desc(*context_.src_md)); // Create an eltwise forward descriptor and primitive descriptor -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_desc.reset(new eltwise_forward::desc( prop_kind::forward, fwdParams.alg_kind, *context_.src_md, fwdParams.alpha, fwdParams.beta)); @@ -185,7 +185,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { context_.fwd_pd.reset(new EltwiseFwdActivationPd( cpu_engine_, prop_kind::forward, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.alpha, fwdParams.beta)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto fwd_pd = context_.fwd_pd.get(); 
// Create memory primitive based on dummy data @@ -301,15 +301,15 @@ class MklEltwiseFwdActivationOpBase : public OpKernel { // Create blocked memory descriptor src_md = MklDnnData::CreateBlockedMemDesc(src_dims, src_strides); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md = src_md; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdActivationParams fwdParams(src_dims, src_md, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 dst_md, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 alg_kind, alpha_, beta_); MklEltwiseFwdActivationPrimitive* eltwise_fwd = MklEltwiseFwdActivationPrimitiveFactory::Get(fwdParams); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 1ae5080895e320..9d4736fc6a83a8 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -41,7 +41,7 @@ using BatchNormBwdPd = dnnl::batch_normalization_backward::primitive_desc; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define FORWARD_INFERENCE prop_kind::forward_scoring #define GET_DIFF_SCALE_DATA_BUFFER diff_scale_shift_data #define GET_DIFF_SCALE_SHIFT_DATA_BUFFERS diff_scale_shift_data @@ -63,7 +63,7 @@ namespace tensorflow { #define SCALE_SHIFT_NET_ARGS \ {DNNL_ARG_SCALE, *context_.scale_mem}, { DNNL_ARG_SHIFT, *context_.shift_mem } #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using CPUDevice = Eigen::ThreadPoolDevice; @@ -77,16 +77,16 @@ struct MklBatchNormFwdParams { TensorFormat data_format; FusedBNActivationMode activation_mode; memory::desc src_md; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, bool training, TensorFormat data_format, memory::desc src_md, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 FusedBNActivationMode activation_mode) : src_dims(src_dims), depth(depth), @@ -94,14 +94,14 @@ struct MklBatchNormFwdParams { training(training), data_format(data_format), activation_mode(activation_mode), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 src_md(src_md) { } #else src_md(src_md), dst_md(dst_md) { } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 }; template @@ -123,18 +123,18 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst // mean_data: output data buffer of means // variance_data: output data buffer of variances -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 void Execute(const T* src_data, const U* scale_shift_data, T* dst_data, #else void Execute(const T* src_data, const U* scale_data, const U* shift_data, T* dst_data, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 U* mean_data, U* variance_data, std::shared_ptr fwd_stream, U* workspace_data) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_stream); @@ -161,7 +161,7 @@ class 
MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(static_cast(dst_data)); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.scale_shift_mem->set_data_handle( static_cast(const_cast(scale_shift_data))); #else @@ -169,7 +169,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { static_cast(const_cast(scale_data))); context_.shift_mem->set_data_handle( static_cast(const_cast(shift_data))); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } if ((context_.pkind == prop_kind::forward_training) || @@ -180,7 +180,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { if (workspace_data != nullptr) { context_.ws_mem->set_data_handle(workspace_data); } -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 // Execute batch-normalization forward primitives. execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); @@ -189,12 +189,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(DummyData); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.scale_shift_mem->set_data_handle(DummyData); #else context_.scale_mem->set_data_handle(DummyData); context_.shift_mem->set_data_handle(DummyData); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } if ((context_.pkind == prop_kind::forward_training) || @@ -225,12 +225,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // Inputs/outputs memory. std::shared_ptr src_mem; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr scale_shift_mem; #else std::shared_ptr scale_mem; std::shared_ptr shift_mem; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr dst_mem; std::shared_ptr mean_mem; std::shared_ptr variance_mem; @@ -249,12 +249,12 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { : flags(0), pkind(prop_kind::forward_training), src_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 scale_shift_mem(nullptr), #else scale_mem(nullptr), shift_mem(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 dst_mem(nullptr), mean_mem(nullptr), variance_mem(nullptr), @@ -275,7 +275,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // Memory descriptor auto src_md = fwdParams.src_md; // Create forward BatchNorm descriptor and primitive descriptor. 
-#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto fwd_desc = batch_normalization_forward::desc( context_.pkind, src_md, fwdParams.eps, static_cast(context_.flags)); @@ -286,7 +286,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.fwd_pd.reset(new BatchNormFwdPd( cpu_engine_, context_.pkind, src_md, dst_md, fwdParams.eps, static_cast(context_.flags))); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -296,7 +296,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { memory::dims m_dims = {1, fwdParams.depth}; if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 memory::dims s_dims = {2, fwdParams.depth}; context_.scale_shift_mem.reset( new memory({{s_dims}, MklDnnType(), memory::format_tag::nc}, @@ -309,7 +309,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.shift_mem.reset( new memory({{s_dims}, MklDnnType(), memory::format_tag::x}, cpu_engine_, DummyData)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } if (fwdParams.training || (IS_SET(use_global_stats))) { @@ -482,18 +482,18 @@ struct MklBatchNormBwdParams { bool training; TensorFormat data_format; memory::desc src_md; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md; memory::desc diff_src_md; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 memory::desc diff_dst_md; MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, int depth, float eps, bool training, TensorFormat data_format, memory::desc src_md, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md, memory::desc diff_src_md, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 memory::desc diff_dst_md) : src_dims(src_dims), diff_dst_dims(diff_dst_dims), @@ -502,10 +502,10 @@ struct MklBatchNormBwdParams { training(training), data_format(data_format), src_md(src_md), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 dst_md(dst_md), diff_src_md(diff_src_md), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 diff_dst_md(diff_dst_md) { } }; @@ -533,19 +533,19 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { // intermediate results is not implemented // on CPU as of now. 
void Execute(const T* src_data, const U* mean_data, const U* variance_data, -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 const T* diff_dst_data, const U* scale_shift_data, T* diff_src_data, U* diff_scale_shift_data, U* res_space_data, #else // oneDNN v3.x does not require 'shift_data' const T* diff_dst_data, const U* scale_data, T* diff_src_data, U* diff_scale_data, U* diff_shift_data, U* res_space_data, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr bwd_stream) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) // TODO(intel-tf): Create a common function and avoid the duplicate code context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *bwd_stream); @@ -576,7 +576,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.scale_shift_mem->set_data_handle( static_cast(const_cast(scale_shift_data))); context_.diff_scale_shift_mem->set_data_handle( @@ -588,11 +588,11 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { static_cast(diff_scale_data)); context_.diff_shift_mem->set_data_handle( static_cast(diff_shift_data)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 // Execute backward batch-normalization primitives. DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size()); execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); @@ -603,14 +603,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.variance_mem->set_data_handle(DummyData); context_.diff_dst_mem->set_data_handle(DummyData); if (IS_SCALE_AND_SHIFT_FLAG_SET) { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.scale_shift_mem->set_data_handle(DummyData); context_.diff_scale_shift_mem->set_data_handle(DummyData); #else context_.scale_mem->set_data_handle(DummyData); context_.diff_scale_mem->set_data_handle(DummyData); context_.diff_shift_mem->set_data_handle(DummyData); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } context_.diff_src_mem->set_data_handle(DummyData); } @@ -631,14 +631,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { std::shared_ptr mean_mem; std::shared_ptr variance_mem; std::shared_ptr diff_dst_mem; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr scale_shift_mem; std::shared_ptr diff_scale_shift_mem; #else std::shared_ptr scale_mem; std::shared_ptr diff_scale_mem; std::shared_ptr diff_shift_mem; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr diff_src_mem; // Backward batch-normalization primitive descriptor. 
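The batch-norm hunks above split the single combined DNNL_ARG_SCALE_SHIFT memory of oneDNN v2.x into separate DNNL_ARG_SCALE and DNNL_ARG_SHIFT arguments, and build the primitive descriptor directly from source/destination descriptors. The following is a minimal forward-training sketch of that v3-style setup, not part of the patch; the 1x4x2x2 shape and the all-ones input are illustrative assumptions.

// oneDNN v3 sketch: batch-norm forward training with separate 1-D scale
// and shift tensors instead of the old combined 2xC scale_shift memory.
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim C = 4;
  auto src_md = memory::desc({1, C, 2, 2}, memory::data_type::f32,
                             memory::format_tag::nchw);
  auto scale_shift_md = memory::desc({C}, memory::data_type::f32,
                                     memory::format_tag::x);  // per-channel, 1-D

  // v3: primitive descriptor built directly, no batch_normalization_forward::desc.
  auto bn_pd = batch_normalization_forward::primitive_desc(
      eng, prop_kind::forward_training, src_md, src_md, /*epsilon=*/1e-5f,
      normalization_flags::use_scale | normalization_flags::use_shift);

  std::vector<float> src(1 * C * 2 * 2, 1.f), scale(C, 1.f), shift(C, 0.f);
  memory src_mem(src_md, eng, src.data());
  memory dst_mem(src_md, eng);
  memory scale_mem(scale_shift_md, eng, scale.data());
  memory shift_mem(scale_shift_md, eng, shift.data());
  memory mean_mem(bn_pd.mean_desc(), eng), var_mem(bn_pd.variance_desc(), eng);

  batch_normalization_forward(bn_pd).execute(
      strm, {{DNNL_ARG_SRC, src_mem},
             {DNNL_ARG_SCALE, scale_mem},  // was part of DNNL_ARG_SCALE_SHIFT in v2.x
             {DNNL_ARG_SHIFT, shift_mem},
             {DNNL_ARG_DST, dst_mem},
             {DNNL_ARG_MEAN, mean_mem},       // outputs in forward_training
             {DNNL_ARG_VARIANCE, var_mem}});
  strm.wait();
  return 0;
}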
@@ -655,14 +655,14 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { mean_mem(nullptr), variance_mem(nullptr), diff_dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 scale_shift_mem(nullptr), diff_scale_shift_mem(nullptr), #else scale_mem(nullptr), diff_scale_mem(nullptr), diff_shift_mem(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 diff_src_mem(nullptr) { } }; @@ -682,19 +682,19 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { memory::format_tag::nc); auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType(), memory::format_tag::nc); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto scale_shift_desc = memory::desc({2, bwdParams.depth}, MklDnnType(), memory::format_tag::nc); #else auto scale_shift_desc = memory::desc({bwdParams.depth}, MklDnnType(), memory::format_tag::x); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto diff_scale_shift_desc = scale_shift_desc; // Forward batch-normalization descriptor and primitive descriptor. // Adding this back due to type difference with context.flags auto bn_flags = GetBatchNormFlags(bwdParams); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto fwd_desc = batch_normalization_forward::desc( prop_kind::forward_training, src_md, bwdParams.eps, static_cast(bn_flags)); @@ -719,7 +719,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { cpu_engine_, prop_kind::backward, diff_src_md, diff_dst_md, src_md, bwdParams.eps, static_cast(bn_flags), fwd_pd)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitives. context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); @@ -728,7 +728,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.variance_mem.reset( new memory(variance_desc, cpu_engine_, DummyData)); context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData)); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.scale_shift_mem.reset( new memory(scale_shift_desc, cpu_engine_, DummyData)); context_.diff_scale_shift_mem.reset( @@ -740,7 +740,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { new memory(diff_scale_shift_desc, cpu_engine_, DummyData)); context_.diff_shift_mem.reset( new memory(diff_scale_shift_desc, cpu_engine_, DummyData)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData)); context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd)); @@ -750,7 +750,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { {DNNL_ARG_VARIANCE, *context_.variance_mem}, {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}, {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem}, -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 {DNNL_ARG_SCALE_SHIFT, *context_.scale_shift_mem}, { DNNL_ARG_DIFF_SCALE_SHIFT, *context_.diff_scale_shift_mem }}); @@ -759,7 +759,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { {DNNL_ARG_SCALE, *context_.scale_mem}, {DNNL_ARG_DIFF_SCALE, *context_.diff_scale_mem}, {DNNL_ARG_DIFF_SHIFT, *context_.diff_shift_mem}}); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 context_.bwd_primitives.push_back(*context_.bn_bwd); } @@ -977,12 +977,12 @@ class MklFusedBatchNormOp : public OpKernel { Tensor* reserved_space_tensor = nullptr; MklDnnData src(&cpu_engine_); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 MklDnnData scale_shift(&cpu_engine_); #else MklDnnData scale(&cpu_engine_); MklDnnData shift(&cpu_engine_); -#endif // ENABLE_ONEDNN_V2 
+#endif // !ENABLE_ONEDNN_V3 MklDnnData wksp(&cpu_engine_); memory::format_tag dnn_fmt; @@ -1009,17 +1009,17 @@ class MklFusedBatchNormOp : public OpKernel { auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType(), dnn_fmt); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 auto dst_md = memory::desc(src_dims, MklDnnType(), dnn_fmt); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 tensor_format_, src_md, activation_mode_); #else tensor_format_, src_md, dst_md, activation_mode_); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. @@ -1062,7 +1062,7 @@ class MklFusedBatchNormOp : public OpKernel { else SetMeanVariance(est_mean_tensor, est_variance_tensor); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // oneDNN packs scale & shift as a combined array in float32 type // ...... scale_shift.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -1083,7 +1083,7 @@ class MklFusedBatchNormOp : public OpKernel { const U* shift_tf = shift_tensor.flat().data(); std::memcpy(scale_data, scale_tf, depth_ * sizeof(U)); std::memcpy(shift_data, shift_tf, depth_ * sizeof(U)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 char* saved_mean_data_tf = reinterpret_cast(saved_mean_tensor->flat().data()); @@ -1124,12 +1124,12 @@ class MklFusedBatchNormOp : public OpKernel { AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst, native_format); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 U* scale_shift_op_data = scale_shift_data; #else U* scale_op_data = scale_data; U* shift_op_data = shift_data; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 U* mean_op_data = saved_mean_tensor->flat().data(); U* variance_op_data = saved_variance_tensor->flat().data(); T* dst_data = dst_tensor->flat().data(); @@ -1138,12 +1138,12 @@ class MklFusedBatchNormOp : public OpKernel { std::shared_ptr fwd_cpu_stream; fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine())); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 bn_fwd->Execute(src_data, scale_shift_op_data, dst_data, mean_op_data, #else bn_fwd->Execute(src_data, scale_op_data, shift_op_data, dst_data, mean_op_data, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 variance_op_data, fwd_cpu_stream, ws_data); float adjust_factor = 1.0; if (is_training_) { @@ -1464,14 +1464,14 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnData src(&cpu_engine_); MklDnnData diff_dst(&cpu_engine_); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 MklDnnData scale_shift(&cpu_engine_); MklDnnData diff_scale_shift(&cpu_engine_); #else MklDnnData scale(&cpu_engine_); MklDnnData diff_scale(&cpu_engine_); MklDnnData diff_shift(&cpu_engine_); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 memory::dims src_dims = dnn_shape_src.IsMklTensor() @@ -1492,11 +1492,11 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_dst.IsMklTensor() ? 
dnn_shape_diff_dst.GetMklLayout() : memory::desc(diff_dst_dims, MklDnnType(), dnn_fmt); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc dst_md = memory::desc(src_dims, MklDnnType(), dnn_fmt); memory::desc diff_src_md = memory::desc(diff_dst_dims, MklDnnType(), dnn_fmt); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 MklDnnData reorder_src(&cpu_engine_); MklDnnData reorder_diff_dst(&cpu_engine_); @@ -1525,7 +1525,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // scale_shift -- oneDNN packs scales/shifts as scale_shift in order // of scale, ..., scale, shift, ...., shift scale_shift.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -1546,13 +1546,13 @@ class MklFusedBatchNormGradOp : public OpKernel { std::memcpy(scale_data_tf, scale_tf, depth_ * sizeof(U)); diff_scale.AllocateBuffer(depth_ * sizeof(U)); diff_shift.AllocateBuffer(depth_ * sizeof(U)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, is_training_, tensor_format_, src_md, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 dst_md, diff_src_md, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 diff_dst_md); Eigen::ThreadPoolInterface* eigen_interface = EigenThreadPoolFromTfContext(context); @@ -1600,7 +1600,7 @@ class MklFusedBatchNormGradOp : public OpKernel { static_cast(const_cast(saved_mean_tensor.flat().data())); U* variance_data = static_cast( const_cast(saved_variance_tensor.flat().data())); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 U* scale_shift_data = scale_shift_data_tf; U* diff_scale_shift_data = static_cast(diff_scale_shift.GetAllocatedBuffer()); @@ -1608,7 +1608,7 @@ class MklFusedBatchNormGradOp : public OpKernel { U* scale_data = scale_data_tf; U* diff_scale_data = static_cast(diff_scale.GetAllocatedBuffer()); U* diff_shift_data = static_cast(diff_shift.GetAllocatedBuffer()); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 T* diff_src_data = static_cast(diff_src_tensor->flat().data()); U* res_space_data = diff --git a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc index c7dd9c585e04b6..b33be9350aedc7 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_instance_norm_op.cc @@ -104,11 +104,11 @@ class MklFusedInstanceNormOp : public OpKernel { } auto src_md = memory::desc(src_dims, MklDnnType(), tag); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define NUM_DUPLICATE 2 #else #define NUM_DUPLICATE 1 -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 memory::dims scale_shift_dims = { static_cast(NUM_DUPLICATE * num_elements_scale)}; auto scale_shift_md = memory::desc(scale_shift_dims, MklDnnType(), @@ -116,7 +116,7 @@ class MklFusedInstanceNormOp : public OpKernel { int64_t tensor_shape = scale_shift_md.get_size() / sizeof(float); #undef NUM_DUPLICATE -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 Tensor scale_shift_tensor; OP_REQUIRES_OK( ctx, ctx->allocate_temp(DataTypeToEnum::v(), {tensor_shape}, @@ -150,12 +150,12 @@ class MklFusedInstanceNormOp : public OpKernel { num_elements_scale, scale_fp32_buf, shift_fp32_buf); auto scale_mem = memory(scale_shift_md, cpu_engine_, scale_fp32_buf); auto shift_mem = memory(scale_shift_md, cpu_engine_, shift_fp32_buf); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 batch_normalization_forward::primitive_desc bnorm_pd; if 
(fuse_activation_) { dnnl::post_ops post_ops; dnnl::primitive_attr post_ops_attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 post_ops.append_eltwise(1.0, dnnl::algorithm::eltwise_relu, leakyrelu_alpha_, 0.0); post_ops_attr.set_post_ops(post_ops); @@ -169,16 +169,16 @@ class MklFusedInstanceNormOp : public OpKernel { cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift, post_ops_attr); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 bnorm_pd = batch_normalization_forward::primitive_desc(bnorm_desc, cpu_engine_); #else bnorm_pd = batch_normalization_forward::primitive_desc( cpu_engine_, prop_kind::forward_inference, src_md, src_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } auto bnorm_prim = batch_normalization_forward(bnorm_pd); @@ -198,12 +198,12 @@ class MklFusedInstanceNormOp : public OpKernel { std::unordered_map bnorm_args; bnorm_args.insert({DNNL_ARG_SRC, *src_mem_ptr}); bnorm_args.insert({DNNL_ARG_DST, *dst_mem_ptr}); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 bnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); #else bnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); bnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Perform batchnorm computation for each batch in input for (int i = 0; i < batch_size; i++) { @@ -287,14 +287,14 @@ class MklFusedInstanceNormOp : public OpKernel { auto data_size = sizeof(float) * num_elements; void* scale_buf_dst = fp32_scale_or_combine_buf; void* shift_buf_dst = nullptr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 shift_buf_dst = static_cast(fp32_scale_or_combine_buf) + data_size; (void)fp32_shift_buf; #else OP_REQUIRES(ctx, (fp32_shift_buf != nullptr), absl::InvalidArgumentError("Invalid shift buffer")); shift_buf_dst = fp32_shift_buf; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 if (std::is_same::value) { memcpy(scale_buf_dst, scale_buf_src, data_size); diff --git a/tensorflow/core/kernels/mkl/mkl_kernel_util.h b/tensorflow/core/kernels/mkl/mkl_kernel_util.h index 4dd6c17dd55f84..2e7ff537cfa000 100644 --- a/tensorflow/core/kernels/mkl/mkl_kernel_util.h +++ b/tensorflow/core/kernels/mkl/mkl_kernel_util.h @@ -49,7 +49,7 @@ class MklTestingUtil { } }; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 // Since oneDNN v3.x exposes only an opaque memory descriptor, it is no longer // possible to cache the entire filter memory descriptor as is. So we store // all relevant information about it in the following class. @@ -94,7 +94,7 @@ class FilterMemoryDesc { memory::dims inner_idxs_; memory::dims strides_; }; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc index fd92b8309cfe9f..ae5ad08b3f4393 100644 --- a/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_layer_norm_op.cc @@ -79,7 +79,7 @@ class MklLayerNormOp : public OpKernel { static_cast(const_cast(src_tensor.flat().data())); auto src_mem = memory(src_md, cpu_engine, src_buf); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // oneDNN v2.x requires scale-shift as a combined array in float32 type. 
memory::dims scale_shift_dims = { 2, static_cast(num_elements_scale)}; @@ -124,7 +124,7 @@ class MklLayerNormOp : public OpKernel { void* shift_buf_dst = static_cast(shift_buf_tensor.flat().data()); auto shift_mem = memory(scale_shift_md, cpu_engine, shift_buf_dst); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 void* scale_buf_src = static_cast(const_cast(scale_tensor.flat().data())); @@ -159,7 +159,7 @@ class MklLayerNormOp : public OpKernel { shift_reorder_prim.execute(*cpu_stream, shift_reorder_args); // Create layer_normalization primitive -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto lnorm_desc = layer_normalization_forward::desc( prop_kind::forward_inference, src_md, epsilon_, normalization_flags::use_scale_shift); @@ -170,7 +170,7 @@ class MklLayerNormOp : public OpKernel { auto lnorm_pd = layer_normalization_forward::primitive_desc( cpu_engine, prop_kind::forward_inference, src_md, dst_md, epsilon_, normalization_flags::use_scale | normalization_flags::use_shift); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 auto lnorm_prim = layer_normalization_forward(lnorm_pd); // mean and variance memory @@ -189,12 +189,12 @@ class MklLayerNormOp : public OpKernel { lnorm_args.insert({DNNL_ARG_SRC, src_mem}); lnorm_args.insert({DNNL_ARG_MEAN, mean_mem}); lnorm_args.insert({DNNL_ARG_VARIANCE, variance_mem}); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 lnorm_args.insert({DNNL_ARG_SCALE_SHIFT, scale_shift_mem}); #else lnorm_args.insert({DNNL_ARG_SCALE, scale_mem}); lnorm_args.insert({DNNL_ARG_SHIFT, shift_mem}); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 lnorm_args.insert({DNNL_ARG_DST, dst_mem}); lnorm_prim.execute(*cpu_stream, lnorm_args); } catch (dnnl::error& e) { diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 60852930bc5a7a..4b86eb50d5863e 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -40,7 +40,7 @@ using dnnl::stream; namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define APPEND_ELTWISE(scale, alg, alpha, beta) \ append_eltwise(scale, alg, alpha, beta) #define APPEND_ELTWISE_RELU6(scale, alpha, beta) \ @@ -58,13 +58,13 @@ namespace tensorflow { (post_op_param.name == "dst_scale") #define SET_MKL_LAYOUT(md) SetMklLayout(md) #define TSCALED_BIAS float -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) #define FWD_STREAM , *fwd_stream #else #define FWD_STREAM -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 static Eigen::internal::CacheSizes cache_sizes = Eigen::internal::CacheSizes(); @@ -208,9 +208,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { std::shared_ptr dst_scale_mem; // Descriptor and primitive-descriptor for forward inner-product. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr fwd_pd; // Memory descriptors. 
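The guards above confine inner_product_forward::desc and the scale-taking append_eltwise overload to the oneDNN v2.x build; under v3.x the primitive descriptor is created directly from the memory descriptors and the post-op attribute, and append_eltwise drops the leading scale argument. A minimal sketch of that v3-style construction with a fused ReLU follows, not part of the patch; the 2x4 to 2x3 sizes are illustrative assumptions and the buffers are left uninitialized.

// oneDNN v3 sketch: inner-product (fully connected) primitive descriptor
// created directly from parameters, with a fused ReLU post-op.
#include <unordered_map>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  auto src_md = memory::desc({2, 4}, memory::data_type::f32, memory::format_tag::nc);
  auto wei_md = memory::desc({3, 4}, memory::data_type::f32, memory::format_tag::oi);
  auto bia_md = memory::desc({3}, memory::data_type::f32, memory::format_tag::x);
  auto dst_md = memory::desc({2, 3}, memory::data_type::f32, memory::format_tag::nc);

  // v3 post-op signature: append_eltwise(alg, alpha, beta); no per-op scale.
  post_ops ops;
  ops.append_eltwise(algorithm::eltwise_relu, 0.f, 0.f);
  primitive_attr attr;
  attr.set_post_ops(ops);

  // v3: no inner_product_forward::desc stage; parameters go to the pd ctor.
  auto ip_pd = inner_product_forward::primitive_desc(
      eng, prop_kind::forward_inference, src_md, wei_md, bia_md, dst_md, attr);

  memory src_mem(src_md, eng), wei_mem(wei_md, eng), bia_mem(bia_md, eng),
      dst_mem(dst_md, eng);

  inner_product_forward(ip_pd).execute(
      strm, {{DNNL_ARG_SRC, src_mem},
             {DNNL_ARG_WEIGHTS, wei_mem},
             {DNNL_ARG_BIAS, bia_mem},
             {DNNL_ARG_DST, dst_mem}});
  strm.wait();
  return 0;
}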
@@ -238,9 +238,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { src_scale_mem(nullptr), wei_scale_mem(nullptr), dst_scale_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), src_md(nullptr), weight_md(nullptr), @@ -282,7 +282,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { memory::format_tag::any)); } // Create an inner-product. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_desc.reset(new inner_product_forward::desc( matmul_fwd_params.const_weight ? prop_kind::forward_inference : prop_kind::forward_training, @@ -290,7 +290,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { *context_.dst_md)); context_.fwd_pd.reset(new inner_product_forward::primitive_desc( *context_.fwd_desc, cpu_engine_)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Check if there is any fusion as post-ops auto const& post_op_params = matmul_fwd_params.post_op_params; @@ -348,7 +348,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { float op_beta = post_op_param.param[2]; post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_logistic, op_alpha, op_beta); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { DCHECK_EQ(post_op_param.param.size(), 1); std::vector scales; @@ -376,7 +376,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { memory::format_tag::x)); context_.dst_scale_mem.reset( new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); float op_scale = post_op_param.param[0]; @@ -395,7 +395,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { post_ops_attr.set_post_ops(post_ops); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_pd.reset(new inner_product_forward::primitive_desc( *context_.fwd_desc, post_ops_attr, cpu_engine_)); #else @@ -405,7 +405,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { : prop_kind::forward_training, *context_.src_md, *context_.weight_md, *context_.bias_md, *context_.dst_md, post_ops_attr)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data context_.src_mem.reset( @@ -428,7 +428,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { {DNNL_ARG_BIAS, *context_.bias_mem}, {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, {DNNL_ARG_DST, *context_.dst_mem}}; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 if (is_scale_set["src"] && is_scale_set["wei"] && is_scale_set["dst"]) { net_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); @@ -437,7 +437,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { net_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); } -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 context_.net_args.push_back(net_args); context_.fwd_primitives.push_back(*context_.matmul_fwd); return; @@ -521,13 +521,13 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { DCHECK_EQ(post_op_param.param.size(), 1); key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.param[0]); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 } else if (post_op_param.name == "output_scale") { #else } else if (post_op_param.name == "src_scale" || post_op_param.name == "wei_scale" || post_op_param.name == "dst_scale") { -#endif // 
ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 DCHECK_EQ(post_op_param.param.size(), 1); key_creator.AddAsKey(post_op_param.name); key_creator.AddAsKey(post_op_param.param[0]); @@ -612,12 +612,12 @@ class MklDnnMatMulOpBase : public OpKernel { return; } -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 // For now, cache weights only for blocked format if (weight_md.get_format_kind() != memory::format_kind::blocked) { return; } -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // reorder and cache the weight weight.SetUsrMem(weight_md, &weight_tensor); @@ -638,7 +638,7 @@ class MklDnnMatMulOpBase : public OpKernel { // cache the memory descriptor auto expected_md = matmul_fwd_pd->weights_desc(); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 TensorShape weight_mkl_format; weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight)); @@ -653,7 +653,7 @@ class MklDnnMatMulOpBase : public OpKernel { expected_md.get_data_type(), expected_md.get_dims(), expected_md.get_inner_blks(), expected_md.get_inner_idxs(), expected_md.get_strides()); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } Tweight* GetCachedWeight(OpKernelContext* context, @@ -661,7 +661,7 @@ class MklDnnMatMulOpBase : public OpKernel { TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& weight_t = weight_oi_; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 const Tensor& weight_md_t = weight_oi_md_; // Check if the memory descriptor of the cached weight is same as @@ -693,7 +693,7 @@ class MklDnnMatMulOpBase : public OpKernel { const_cast(weight_t.flat().data())); } return nullptr; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } bool IsBiasCacheEmpty() TF_LOCKS_EXCLUDED(bias_cache_mutex_) { @@ -740,11 +740,11 @@ class MklDnnMatMulOpBase : public OpKernel { // Tensor to save reordered weight mutex mu_; Tensor weight_oi_ TF_GUARDED_BY(mu_); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 Tensor weight_oi_md_ TF_GUARDED_BY(mu_); #else FilterMemoryDesc weight_oi_md_ TF_GUARDED_BY(mu_); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 bool is_weight_const_; @@ -814,7 +814,7 @@ class MklMatMulPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.a_mem->set_data_handle( static_cast(const_cast(a_data)), *stream); context_.b_mem->set_data_handle( @@ -837,7 +837,7 @@ class MklMatMulPrimitive : public MklPrimitive { context_.sp_mem->set_data_handle(sp_data); if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data); if (add_data != nullptr) context_.add_mem->set_data_handle(add_data); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 execute_primitives(context_.matmul_primitives, stream, context_.net_args); // After execution, set data handle back @@ -865,9 +865,9 @@ class MklMatMulPrimitive : public MklPrimitive { std::shared_ptr sp_mem; // Descriptor and primitive-descriptor for MatMul. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr prim_desc; // Memory descriptors. 
@@ -888,9 +888,9 @@ class MklMatMulPrimitive : public MklPrimitive { mul_mem(nullptr), add_mem(nullptr), sp_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 prim_desc(nullptr), a_md(nullptr), b_md(nullptr), @@ -917,10 +917,10 @@ class MklMatMulPrimitive : public MklPrimitive { params.c_strides)); // Create matmul. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.desc.reset( new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Check if there is any fusion as post-ops auto const& post_op_params = params.post_op_params; @@ -929,14 +929,14 @@ class MklMatMulPrimitive : public MklPrimitive { if (!post_op_params.empty()) { for (auto const& post_op_param : post_op_params) { if (post_op_param.name == "output_scale") { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // TODO(intel-tf): Verify if this code is needed. If not, it needs to // be removed. DCHECK_EQ(post_op_param.param.size(), 1); std::vector scales; scales.push_back(post_op_param.param[0]); post_ops_attr.set_output_scales(0, scales); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } else if (post_op_param.name == "mul") { context_.mul_md.reset(new memory::desc({post_op_param.dims}, post_op_param.data_type, @@ -954,14 +954,14 @@ class MklMatMulPrimitive : public MklPrimitive { post_ops_attr.set_post_ops(post_ops); } post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.prim_desc.reset( new matmul::primitive_desc(*context_.desc, post_ops_attr, cpu_engine_)); #else context_.prim_desc.reset( new matmul::primitive_desc(cpu_engine_, *context_.a_md, *context_.b_md, *context_.c_md, post_ops_attr)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data. 
context_.a_mem.reset( diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index 304b499b62d3a0..40e8529d209424 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -116,12 +116,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(), this->data_format_tf_); memory::dims filter_dims, strides, padding_left, padding_right; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 &padding_left, &padding_right, is_pool2d); // Get a pooling op from the cached pool @@ -134,11 +134,11 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { else pooling_prop_kind = prop_kind::forward_training; MklPoolingParams fwdParams( -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 src_dims, output_dims_mkl_order, filter_dims, strides, #else src_dims, output_dims_mkl_order, filter_dims, strides, dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 padding_left, padding_right, dnnl::algorithm::pooling_max, pooling_prop_kind, static_cast(this->data_format_mkldnn_), input_md, @@ -291,12 +291,12 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_shape); memory::dims filter_dims, strides, padding_left, padding_right; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, #else memory::dims dilations; this->PoolParamsToDims(&pool_params, &filter_dims, &strides, &dilations, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 &padding_left, &padding_right, is_pool2d); memory::dims orig_input_dims_mkl_order = @@ -333,12 +333,12 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 strides, padding_left, padding_right, dnnl::algorithm::pooling_max, #else strides, dilations, padding_left, padding_right, dnnl::algorithm::pooling_max, -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 prop_kind::forward_training, static_cast(this->data_format_mkldnn_), src_md, this->native_format_); diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index d53b8f37eef2e2..2df1e7879fdba9 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -25,14 +25,14 @@ limitations under the License. #include "tensorflow/core/framework/kernel_shape_util.h" namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define AVG_POOLING_DCHECK(params) \ params.alg_kind == dnnl::algorithm::pooling_avg || #define GET_MEMORY_DESC(md) md.data #else #define AVG_POOLING_DCHECK(params) #define GET_MEMORY_DESC(md) md -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using dnnl::prop_kind; template @@ -58,7 +58,7 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { : memory::format_tag::any)); // Create a pooling descriptor. 
-#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_desc.reset(new pooling_forward::desc( fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, @@ -70,7 +70,7 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { cpu_engine_, fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, fwdParams.dilations, fwdParams.padding_left, fwdParams.padding_right)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 context_.dst_fmt = static_cast(memory::format_tag::any); // Create oneDNN internal memory object with dummy data. @@ -104,7 +104,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_stream); context_.dst_mem->set_data_handle(static_cast(dst_data), *fwd_stream); @@ -124,7 +124,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); // Set back data handle. @@ -162,7 +162,7 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { ? bwdParams.src_format : memory::format_tag::any)); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // Create a backward primitive. The implementation for backward must comply to // the workspace format it gets from forward pass, so we directly use src_md // and dst_md here. @@ -188,7 +188,7 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { cpu_engine_, bwdParams.alg_kind, *context_.src_md, *context_.dst_md, bwdParams.strides, bwdParams.filter_dims, bwdParams.dilations, bwdParams.padding_left, bwdParams.padding_right, *context_.fwd_pd)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create oneDNN internal memory object with dummy data. context_.diff_src_mem.reset(new memory(context_.bwd_pd.get()->diff_src_desc(), @@ -219,7 +219,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data)), *bwd_stream); context_.diff_src_mem->set_data_handle(static_cast(diff_src_data), @@ -236,7 +236,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast(ws_data)); } -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); @@ -325,13 +325,13 @@ void MklPoolParameters::Init(OpKernelContext* context, col_stride = GetTensorDim(stride, data_format, 'W'); depth_stride = GetTensorDim(stride, data_format, 'C'); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 // TODO(intel-tf): we are setting dilations to 0 to mimic the behavior of // oneDNN v2.x integration code. 
We can extend this in the future to support // dilations != 0 row_dilation = 0; col_dilation = 0; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // We only support 2D pooling across width/height and depthwise // pooling, not a combination. @@ -354,12 +354,12 @@ void MklPoolParameters::Init(OpKernelContext* context, col_stride = GetTensorDim(stride, data_format, '2'); depth_stride = GetTensorDim(stride, data_format, 'C'); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 // TODO(intel-tf): TensorFlow's 3D-pooling API does not support dilations planes_dilation = 0; row_dilation = 0; col_dilation = 0; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // We only support 3D pooling across depth/width/height and depthwise // pooling, not a combination. diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h index d3ad93f73c264b..012244a79691ad 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -33,13 +33,13 @@ limitations under the License. namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define GET_DIMS data.dims #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define GET_DIMS get_dims() #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 using dnnl::pooling_backward; using dnnl::pooling_forward; @@ -54,9 +54,9 @@ struct MklPoolingParams { memory::dims dst_dims; memory::dims filter_dims; memory::dims strides; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::dims dilations; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 memory::dims padding_left; memory::dims padding_right; dnnl::algorithm alg_kind; @@ -67,9 +67,9 @@ struct MklPoolingParams { MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::dims dilations, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 memory::dims padding_left, memory::dims padding_right, dnnl::algorithm alg_kind, dnnl::prop_kind prop_kind, memory::format_tag src_format, memory::desc src_md, @@ -78,9 +78,9 @@ struct MklPoolingParams { dst_dims(dst_dims), filter_dims(filter_dims), strides(strides), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 dilations(dilations), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 padding_left(padding_left), padding_right(padding_right), alg_kind(alg_kind), @@ -141,9 +141,9 @@ class MklPoolingFwdPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // Pooling forward descriptor and primitive descriptor. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr fwd_pd; // Memory descriptor. 
@@ -164,9 +164,9 @@ class MklPoolingFwdPrimitive : public MklPrimitive { ws_mem(nullptr), src_mem(nullptr), dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), src_md(nullptr), dst_md(nullptr), @@ -220,9 +220,9 @@ class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(fwdParams.dst_dims); key_creator.AddAsKey(fwdParams.filter_dims); key_creator.AddAsKey(fwdParams.strides); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 key_creator.AddAsKey(fwdParams.dilations); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 key_creator.AddAsKey(fwdParams.padding_left); key_creator.AddAsKey(fwdParams.padding_right); key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); @@ -297,10 +297,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive { std::shared_ptr dst_md; // Forward and backward pooling descriptors and primitive descriptors. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; std::shared_ptr bwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 std::shared_ptr fwd_pd; std::shared_ptr bwd_pd; @@ -320,10 +320,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive { diff_dst_mem(nullptr), src_md(nullptr), dst_md(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), bwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 fwd_pd(nullptr), bwd_pd(nullptr), bwd(nullptr) { @@ -376,9 +376,9 @@ class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(bwdParams.dst_dims); key_creator.AddAsKey(bwdParams.filter_dims); key_creator.AddAsKey(bwdParams.strides); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 key_creator.AddAsKey(bwdParams.dilations); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 key_creator.AddAsKey(bwdParams.padding_left); key_creator.AddAsKey(bwdParams.padding_right); key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); @@ -416,11 +416,11 @@ struct MklPoolParameters { int col_stride; int depth_stride; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 int planes_dilation; // Pool3D int row_dilation; int col_dilation; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 int64 out_planes; // Pool3D int64 out_height; @@ -450,11 +450,11 @@ struct MklPoolParameters { row_stride(0), col_stride(0), depth_stride(0), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 planes_dilation(0), row_dilation(0), col_dilation(0), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 out_planes(0), out_height(0), out_width(0), @@ -572,9 +572,9 @@ class MklPoolingOpBase : public OpKernel { void PoolParamsToDims(const MklPoolParameters* pool_params, memory::dims* filter_dims, memory::dims* strides, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::dims* dilations, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 memory::dims* padding_left, memory::dims* padding_right, bool is_pool2d) { if (is_pool2d) { @@ -583,10 +583,10 @@ class MklPoolingOpBase : public OpKernel { memory::dims({pool_params->window_rows, pool_params->window_cols}); *strides = memory::dims({pool_params->row_stride, pool_params->col_stride}); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 *dilations = memory::dims({pool_params->row_dilation, pool_params->col_dilation}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 *padding_left = memory::dims({static_cast(pool_params->pad_top), static_cast(pool_params->pad_left)}); 
*padding_right = memory::dims({static_cast(pool_params->pad_bottom), @@ -599,11 +599,11 @@ class MklPoolingOpBase : public OpKernel { *strides = memory::dims({pool_params->planes_stride, pool_params->row_stride, pool_params->col_stride}); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 *dilations = memory::dims({pool_params->planes_dilation, pool_params->row_dilation, pool_params->col_dilation}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 *padding_left = memory::dims({static_cast(pool_params->pad_P1), static_cast(pool_params->pad_top), diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc index 8e0e85f1b5c4a9..987bda960b4519 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc @@ -96,7 +96,7 @@ limitations under the License. // // More information of this implementation can be found in // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training -#ifdef INTEL_MKL +#if defined(INTEL_MKL) #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/fill_functor.h" @@ -116,11 +116,11 @@ enum { namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define TSCALED_BIAS Tbias #else #define TSCALED_BIAS float -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 template @@ -320,7 +320,7 @@ class MklDnnQuantizedMatMulOp UserScratchPad scratch_pad; scratch_pad.AllocateSPTensor(matmul_fwd, context); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 Tbias* bias_data = this->GetBiasHandle( context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream); #else @@ -330,7 +330,7 @@ class MklDnnQuantizedMatMulOp Tensor temp_scaled_bias_tensor; this->GetBiasHandle(context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream, &temp_scaled_bias_tensor, &bias_data); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Execute inner-product matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data, matmul_fwd_dims, scratch_pad.Get(), cpu_stream); @@ -428,7 +428,7 @@ class MklDnnQuantizedMatMulOp absl::InvalidArgumentError(absl::StrCat( "`max_b` must be rank 0 but is rank ", max_weight_tensor.dims()))); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 const float min_input = min_input_tensor.scalar()(); const float max_input = max_input_tensor.scalar()(); const float min_weight = min_weight_tensor.scalar()(); @@ -441,7 +441,7 @@ class MklDnnQuantizedMatMulOp float wei_scale = std::max(std::abs(min_weight), std::abs(max_weight)) / 127.0; float dst_scale = 1.0; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // When the output type is quint8, the output data is requantized into // quint8. A post_op "output_scale" is added to do the conversion. 
if (std::is_same::value || @@ -464,7 +464,7 @@ class MklDnnQuantizedMatMulOp const float max_freezed_output = max_freezed_tensor.scalar()(); float scale_eightbit = std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 float min_output_value; float max_output_value; ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value); @@ -504,7 +504,7 @@ class MklDnnQuantizedMatMulOp params.post_op_params.push_back({"src_scale", {src_scale}}); params.post_op_params.push_back({"wei_scale", {wei_scale}}); params.post_op_params.push_back({"dst_scale", {dst_scale}}); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 } // This function handles bias conversion and compensation for MIN_FIRST and @@ -512,7 +512,7 @@ class MklDnnQuantizedMatMulOp // B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32 // If input is quantized via SCALE, // Bs32 = Qa * Qw * Bf32. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 Tbias* GetBiasHandle( OpKernelContext* context, std::shared_ptr& @@ -753,7 +753,7 @@ class MklDnnQuantizedMatMulOp } return false; } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 private: memory* input_bias_ = nullptr; diff --git a/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc b/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc index 95baa8b0f5eb3b..22b56e19e3bb63 100644 --- a/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_qmatmul_op_test.cc @@ -273,11 +273,11 @@ TEST_F(QuantizedMatMulTest, Small_withBiasAndReq) { // 178 * 1.00392 ~= 178.698 ~= 179 Tensor expected(allocator(), DT_QUINT8, TensorShape({2, 4})); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 test::FillValues(&expected, {84, 60, 116, 52, 183, 168, 233, 178}); #else test::FillValues(&expected, {84, 60, 116, 52, 184, 169, 234, 179}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 const Tensor& output = *GetOutput(0); test::ExpectTensorEqual(expected, output); @@ -470,11 +470,11 @@ TEST_F(QuantizedMatMulTest, Small_withBiasAndReluAndReq) { // 178 * 1.00392 ~= 178.698 ~= 179 Tensor expected(allocator(), DT_QUINT8, TensorShape({2, 4})); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 test::FillValues(&expected, {84, 60, 116, 52, 183, 168, 233, 178}); #else test::FillValues(&expected, {84, 60, 116, 52, 184, 169, 234, 179}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 const Tensor& output = *GetOutput(0); test::ExpectTensorEqual(expected, output); diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index 3e17aa2e162c56..36e7178d5b7249 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -57,11 +57,11 @@ enum { namespace tensorflow { -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define SET_MKL_LAYOUT(md) SetMklLayout(&md) #else #define SET_MKL_LAYOUT(md) SetMklLayout(md) -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 typedef Eigen::ThreadPoolDevice CPUDevice; @@ -69,9 +69,9 @@ struct MklReorderWithScaleFwdParams { memory::dims src_dims; memory::desc src_md; memory::desc dst_md; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 memory::desc scale_md; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 string dtypes = string(""); struct PostOpParam { string name; @@ -79,7 +79,7 @@ struct MklReorderWithScaleFwdParams { }; PostOpParam post_op_params; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 
MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md, memory::desc dst_md) : src_dims(src_dims), src_md(src_md), dst_md(dst_md) {} @@ -90,7 +90,7 @@ struct MklReorderWithScaleFwdParams { src_md(src_md), dst_md(dst_md), scale_md(scale_md) {} -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 }; class MklReorderWithScalePrimitive : public MklPrimitive { @@ -107,30 +107,30 @@ class MklReorderWithScalePrimitive : public MklPrimitive { std::shared_ptr GetPrimitive() { return context_.reorder_prim; } void Execute(void* src_data, void* dst_data, -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 void* scale_data, -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 std::shared_ptr reorder_stream) { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.src_mem->set_data_handle(src_data, *reorder_stream); context_.dst_mem->set_data_handle(dst_data, *reorder_stream); #else context_.src_mem->set_data_handle(src_data); context_.dst_mem->set_data_handle(dst_data); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 -#ifndef ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 +#ifdef ENABLE_ONEDNN_V3 context_.scale_mem->set_data_handle(scale_data); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 context_.reorder_prim->execute(*reorder_stream, context_.prim_args); // After execution, set data handle back. context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 context_.scale_mem->set_data_handle(DummyData); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } private: @@ -139,9 +139,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive { // MKL-DNN memory std::shared_ptr src_mem; std::shared_ptr dst_mem; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 std::shared_ptr scale_mem; -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // Reorder primitive descriptor and primitive std::shared_ptr reorder_pd; @@ -155,9 +155,9 @@ class MklReorderWithScalePrimitive : public MklPrimitive { ReorderContext() : src_mem(nullptr), dst_mem(nullptr), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 scale_mem(nullptr), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 reorder_pd(nullptr), reorder_prim(nullptr) { } @@ -170,14 +170,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive { new memory(fwdParams.src_md, cpu_engine_, DummyData)); context_.dst_mem.reset( new memory(fwdParams.dst_md, cpu_engine_, DummyData)); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 context_.scale_mem.reset( new memory(fwdParams.scale_md, cpu_engine_, DummyData)); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 // Check if there is any fusion as post-ops dnnl::primitive_attr post_ops_attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 auto const& post_op_params = fwdParams.post_op_params; DCHECK(post_op_params.name == "scale"); DCHECK_EQ(post_op_params.param.size(), 1); @@ -186,7 +186,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { post_ops_attr.set_output_scales(0, scales); #else post_ops_attr.set_scales_mask(DNNL_ARG_SRC, 0 /* mask */); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 context_.reorder_pd.reset( new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_, @@ -196,10 +196,10 @@ class MklReorderWithScalePrimitive : public MklPrimitive { 
context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); context_.prim_args.insert({DNNL_ARG_FROM, *context_.src_mem}); context_.prim_args.insert({DNNL_ARG_TO, *context_.dst_mem}); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 context_.prim_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.scale_mem}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 } #ifdef DNNL_AARCH64_USE_ACL @@ -437,9 +437,9 @@ class MklQuantizeV2Op : public OpKernel { // they are wrapper MklDnnData src(&cpu_engine); MklDnnData dst(&cpu_engine); -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 MklDnnData scale(&cpu_engine); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 auto src_md = src_mkl_shape.IsMklTensor() @@ -538,7 +538,7 @@ class MklQuantizeV2Op : public OpKernel { const int64 number_of_steps = static_cast(1) << number_of_bits; scale_factor = (number_of_steps - 1.0) / (max_range - min_range); } -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 auto scale_md = memory::desc({1}, MklDnnType(), memory::format_tag::x); MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md, scale_md); @@ -549,7 +549,7 @@ class MklQuantizeV2Op : public OpKernel { #else MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md); fwdParams.dtypes.append(typeid(T).name()); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 fwdParams.post_op_params.name = "scale"; fwdParams.post_op_params.param.push_back(scale_factor); @@ -566,9 +566,9 @@ class MklQuantizeV2Op : public OpKernel { cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine())); reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(), -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 scale.GetUsrMemDataHandle(), -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 cpu_stream); output_min_tensor->scalar()() = min_range; diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 2d7064b51692e9..c239d958558a9c 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/nn_ops.cc. -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V2) +#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) // TODO(intel-tf): This file is no longer used and needs to be removed. // This file will be an empty compilation unit when building with oneDNN v3.x // (default behavior). It can be compiled only when building with oneDNN v2.x. @@ -1240,4 +1240,4 @@ TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES); } // namespace tensorflow -#endif // INTEL_MKL && ENABLE_ONEDNN_V2 +#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc index f39b96f3606f51..740de287de37d6 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V2) && defined(ENABLE_MKL) +#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3) && defined(ENABLE_MKL) // TODO(intel-tf): This file is no longer used and needs to be removed. 
// This file will be an empty compilation unit when building with oneDNN v3.x // (default behavior). It can be compiled only when building with oneDNN v2.x. @@ -139,4 +139,4 @@ TEST_ALL_SIZES(LeakyReluGrad) } // namespace tensorflow -#endif // INTEL_MKL && ENABLE_ONEDNN_V2 && ENABLE_MKL +#endif // INTEL_MKL && !ENABLE_ONEDNN_V3 && ENABLE_MKL diff --git a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc index 6080df45c1ae5a..ecf518baf4ed98 100644 --- a/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc @@ -106,7 +106,7 @@ class MklRequantizePerChannelOp : public OpKernel { } dnnl::primitive_attr reorder_attr; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 reorder_attr.set_output_scales(2, scales); #else reorder_attr.set_scales_mask(DNNL_ARG_SRC, 2); @@ -114,7 +114,7 @@ class MklRequantizePerChannelOp : public OpKernel { MklDnnType(), memory::format_tag::x}, cpu_engine_, scales.data()); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create the oneDNN wrapper over Eigen threadpool and set max threads // in oneDNN. @@ -157,9 +157,9 @@ class MklRequantizePerChannelOp : public OpKernel { reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_)); std::unordered_map reorder_args = { {DNNL_ARG_FROM, *input_mem_prim}, {DNNL_ARG_TO, *output_mem_prim}}; -#ifndef ENABLE_ONEDNN_V2 +#ifdef ENABLE_ONEDNN_V3 reorder_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scale_mem}); -#endif // !ENABLE_ONEDNN_V2 +#endif // ENABLE_ONEDNN_V3 std::unique_ptr reorder_prim( new dnnl::reorder(reorder_pd)); reorder_prim->execute(*reorder_stream, reorder_args); diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 7a7e14c1d90524..50caffa5e2d1e7 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -67,7 +67,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { #ifdef DNNL_AARCH64_USE_ACL mutex_lock lock(primitive_execution_mu_); #endif -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) context_.src_mem->set_data_handle( static_cast(const_cast(src_data)), *fwd_cpu_stream); context_.dst_mem->set_data_handle(static_cast(dst_data), @@ -76,7 +76,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size()); execute_primitives(context_.fwd_primitives, fwd_cpu_stream, @@ -98,9 +98,9 @@ class MklSoftmaxPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // Primitive descriptor. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 std::shared_ptr fwd_desc; -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Memory descriptor. 
std::shared_ptr src_md; @@ -115,9 +115,9 @@ class MklSoftmaxPrimitive : public MklPrimitive { SoftmaxFwdContext() : src_mem(nullptr), dst_mem(nullptr), -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 fwd_desc(nullptr), -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 src_md(nullptr), fwd_pd(nullptr), softmax_fwd(nullptr) { @@ -132,7 +132,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { new memory::desc({fwdParams.src_dims}, MklDnnType(), src_format)); // Create softmax descriptor and primitive descriptor. -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 context_.fwd_desc.reset(new dnnl::softmax_forward::desc( prop_kind::forward_scoring, *context_.src_md, fwdParams.axis)); context_.fwd_pd.reset(new dnnl::softmax_forward::primitive_desc( @@ -142,7 +142,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { cpu_engine_, prop_kind::forward_inference, dnnl::algorithm::softmax_accurate, *context_.src_md, *context_.src_md /* dst_md */, fwdParams.axis)); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // Create memory primitive based on dummy data. context_.src_mem.reset( diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index e426221ff844d5..caa0b11305d199 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -159,7 +159,7 @@ inline void execute_primitives( } } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 #define ARE_MEMORY_DESCS_EQUAL(md1, md2) dnnl_memory_desc_equal(&md1, &md2) #define CREATE_MEMORY_DESC_USING_STRIDES dnnl_memory_desc_init_by_strides #define GET_DATA_TYPE data_type @@ -193,7 +193,7 @@ inline void execute_primitives( #define GET_STRIDES_DIMS(dims, dims_outer_blocks) dims #define INIT_DIMS_FROM_DESC(in_dims, md) in_dims = md.get_dims() #define MEMORY_DESC memory::desc -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 // In oneDNN v1.x, the format (ex. NCHW) used to initialize a memory descriptor // (md) structure will no longer be recorded in its `format` field. 
Instead, it @@ -477,14 +477,14 @@ class MklDnnShape { inline void SetElemType(memory::data_type dt) { data_.T_ = dt; } inline const memory::data_type GetElemType() { return data_.T_; } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 inline void SetMklLayout(memory::desc* md) { CHECK_NOTNULL(md); data_.mkl_md_ = md->data; } #else inline void SetMklLayout(const memory::desc& md) { data_.mkl_md_ = md; } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 inline const memory::desc GetMklLayout() const { return memory::desc(data_.mkl_md_); @@ -1342,7 +1342,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, std::vector net; net.push_back(dnnl::reorder(reorder_desc)); std::vector net_args; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); #else if (scale_mem != nullptr) { @@ -1352,7 +1352,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc, } else { net_args.push_back({{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}}); } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 ExecutePrimitive(net, &net_args, engine, ctx); } @@ -1514,11 +1514,11 @@ class MklDnnData { std::shared_ptr t_stream = nullptr) { CHECK_NOTNULL(user_memory_); CHECK_NOTNULL(data_buffer); -#if !defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_ONEDNN_V2) +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) user_memory_->set_data_handle(data_buffer, *t_stream); #else user_memory_->set_data_handle(data_buffer); -#endif // !ENABLE_ONEDNN_OPENMP && ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 } /// Set function for data buffer of user memory primitive. @@ -2195,7 +2195,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { auto to_inner_blks = to_desc.GET_INNER_BLKS; auto to_inner_idxs = to_desc.GET_INNER_IDXS; auto to_strides = to_desc.GET_STRIDES; -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 memory::dims from_inner_blks_1(from_inner_blks, &from_inner_blks[from_inner_nblks]); memory::dims from_inner_idxs_1(from_inner_idxs, @@ -2206,7 +2206,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { &from_strides[from_desc.ndims]); memory::dims to_strides_outer_blocks(to_strides, &to_strides[to_desc.ndims]); -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 key_creator.AddAsKey(prefix); #ifdef DNNL_AARCH64_USE_ACL diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc index 30799c72dcb538..2621d9ad27e207 100644 --- a/tensorflow/core/util/mkl_util_test.cc +++ b/tensorflow/core/util/mkl_util_test.cc @@ -53,7 +53,7 @@ TEST(MklUtilTest, MklDnnTfShape) { EXPECT_NE(b_tf_shape_nchw, b_mkldnn_tf_shape); } -#ifdef ENABLE_ONEDNN_V2 +#ifndef ENABLE_ONEDNN_V3 // TODO(intel-tf): This code is not tested for oneDNN v3.x and needs to be // removed TEST(MklUtilTest, MklDnnBlockedFormatTest) { @@ -83,7 +83,7 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) { EXPECT_EQ(b_md2.data.dims[0], 3); EXPECT_EQ(b_md2.data.dims[1], 4); } -#endif // ENABLE_ONEDNN_V2 +#endif // !ENABLE_ONEDNN_V3 TEST(MklUtilTest, LRUCacheTest) { // The cached objects are of type int* diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 875f35e4ce9e29..eaae8b80ed0714 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -54,6 +54,10 @@ load( "if_mkldnn_aarch64_acl_openmp", "if_mkldnn_openmp", ) +load( + "//tensorflow/tsl/mkl:build_defs.bzl", + "onednn_v3_define", +) load( 
"//third_party/compute_library:build_defs.bzl", "if_enable_acl", @@ -440,6 +444,7 @@ def tf_copts( # optimizations for Intel builds using oneDNN if configured if_enable_mkl(["-DENABLE_MKL"]) + if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) + + onednn_v3_define() + if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1", "-DENABLE_ONEDNN_V2=1"]) + if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP", "-DENABLE_ONEDNN_V2=1"]) + if_zendnn(["-DAMD_ZENDNN"]) + diff --git a/tensorflow/tsl/mkl/build_defs.bzl b/tensorflow/tsl/mkl/build_defs.bzl index 28d15f2b4fcfca..13400d341ebd2e 100644 --- a/tensorflow/tsl/mkl/build_defs.bzl +++ b/tensorflow/tsl/mkl/build_defs.bzl @@ -107,6 +107,21 @@ def mkl_deps(): "//conditions:default": [], }) +def onednn_v3_define(): + """Returns a define to indicate if oneDNN v3.x is enabled or not. It is + defined only on Linux and Windows x86 builds. + Returns none for all other cases (including ARM builds). + + Returns: + a select evaluating to a define or none as appropriate. + """ + return select({ + "@org_tensorflow//tensorflow/tsl/mkl:build_with_mkl_aarch64": [], + "@org_tensorflow//tensorflow/tsl:linux_x86_64": ["-DENABLE_ONEDNN_V3"], + "@org_tensorflow//tensorflow/tsl:windows": ["-DENABLE_ONEDNN_V3"], + "//conditions:default": [], + }) + def _enable_local_mkl(repository_ctx): return _TF_MKL_ROOT in repository_ctx.os.environ From 59a068276f0e992ad3368091d0f1e904a6b69b6d Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Mon, 24 Jul 2023 16:57:27 -0700 Subject: [PATCH 082/410] Internal TFRT change PiperOrigin-RevId: 550713100 --- tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index dfc673b86d4019..2001bfc7186ca9 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -13,7 +13,7 @@ td_library( ], includes = ["."], visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", "//tensorflow/core/tfrt/mlrt:__subpackages__", @@ -52,7 +52,7 @@ cc_library( "mlrt_ops.h", ], visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", "//learning/infra/mira/distributed:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", ], @@ -72,7 +72,7 @@ td_library( ], includes = ["."], visibility = [ - # copybara:uncomment "//learning/brain/experimental/tfrt:__subpackages__", + # copybara:uncomment "//learning/brain/tfrt/mlir:__subpackages__", # copybara:uncomment "//learning/infra/mira/distributed:__subpackages__", ], deps = [ From a10d81efa3295983995db2e28de441f6bd6dce31 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Mon, 24 Jul 2023 17:04:31 -0700 Subject: [PATCH 083/410] [PJRT C API] Add logging for the PJRT C API version of framework and plugin. 
PiperOrigin-RevId: 550714686 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc | 11 ++++++++--- tensorflow/compiler/xla/pjrt/pjrt_api.cc | 10 ++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc index 6e40ebf8539e7c..501531e0747acc 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -539,9 +539,14 @@ PJRT_SerializedExecutableDeleter MakeSerializedExecutableDeleter( static std::string StructSizeErrorMsg(absl::string_view struct_name, size_t expected_size, size_t actual_size) { - return absl::StrCat("Unexpected ", struct_name, " size: expected ", - expected_size, ", got ", actual_size, - ". Check installed software versions."); + std::string error_msg = absl::StrCat( + "Unexpected ", struct_name, " size: expected ", expected_size, ", got ", + actual_size, ". Check installed software versions."); +#if defined(PJRT_API_MAJOR) + absl::StrAppend(&error_msg, " The framework PJRT API version is ", + PJRT_API_MAJOR, ".", PJRT_API_MINOR, "."); +#endif // PJRT_API_MAJOR + return error_msg; } xla::Status CheckMatchingStructSizes(absl::string_view struct_name, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_api.cc b/tensorflow/compiler/xla/pjrt/pjrt_api.cc index 5861f074a63fc1..ebb31db4c4f334 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_api.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_api.cc @@ -62,6 +62,16 @@ xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api) { xla::Status InitPjrtPlugin(PjrtApiInitFn init_fn, absl::string_view device_type) { const PJRT_Api* pjrt_api = init_fn(); + // TODO(jieying): 592 is the size of PJRT_Api right after PJRT_Api_Version is + // added. Remove this check after PJRT C API is stable and we assume all + // plugins uses PJRT C API with PJRT_Api_Version. + if (pjrt_api->struct_size >= 592) { + LOG(INFO) << "PJRT plugin for " << device_type << " has PJRT API version " + << pjrt_api->pjrt_api_version.major_version << "." + << pjrt_api->pjrt_api_version.minor_version + << ". The framework PJRT API version is " << PJRT_API_MAJOR << "." 
+ << PJRT_API_MINOR << "."; + } TF_RETURN_IF_ERROR(pjrt::CheckMatchingStructSizes( "PJRT_Api", PJRT_Api_STRUCT_SIZE, pjrt_api->struct_size)); return SetPjrtApi(device_type, pjrt_api); From 7220ef46163748e493e8ad975295e01d0c6a1ad3 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 24 Jul 2023 17:25:14 -0700 Subject: [PATCH 084/410] Removed changes which are no longer needed --- tensorflow/tensorflow.bzl | 4 ++-- tensorflow/tsl/tsl.bzl | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index eaae8b80ed0714..bb1d1394785c1d 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -445,8 +445,8 @@ def tf_copts( if_enable_mkl(["-DENABLE_MKL"]) + if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) + onednn_v3_define() + - if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1", "-DENABLE_ONEDNN_V2=1"]) + - if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP", "-DENABLE_ONEDNN_V2=1"]) + + if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) + + if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) + if_zendnn(["-DAMD_ZENDNN"]) + if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) + if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) + diff --git a/tensorflow/tsl/tsl.bzl b/tensorflow/tsl/tsl.bzl index ce729c50664e5b..df1492066739f0 100644 --- a/tensorflow/tsl/tsl.bzl +++ b/tensorflow/tsl/tsl.bzl @@ -30,6 +30,10 @@ load( "if_mkldnn_aarch64_acl_openmp", "if_mkldnn_openmp", ) +load( + "//tensorflow/tsl/mkl:build_defs.bzl", + "onednn_v3_define", +) load( "//third_party/compute_library:build_defs.bzl", "if_enable_acl", @@ -253,8 +257,9 @@ def tsl_copts( # optimizations for Intel builds using oneDNN if configured if_enable_mkl(["-DENABLE_MKL"]) + if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) + - if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1", "-DENABLE_ONEDNN_V2=1"]) + - if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP", "-DENABLE_ONEDNN_V2=1"]) + + onednn_v3_define() + + if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) + + if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) + if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) + if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) + if_linux_x86_64(["-msse3"]) + From ae206bbd788fbb129c905aefd45e68e3dd5fd6fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 18:28:34 -0700 Subject: [PATCH 085/410] Add side effect modeling of _XlaRun op. There are currently no side effect traits associated with tf._XlaRun and as a result its side effect is treated as unknown. This implements MemoryEffectsOpInterface to relax the conservatism. 
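To make the new modeling concrete, a hedged sketch follows (the function, operand names, and tensor types are placeholders, not taken from this change): each `tf._XlaRun` now claims a write on a dedicated `_XlaRun` resource plus conservative read/write effects on its resource-typed operands, so consecutive runs stay ordered with each other while ops on unrelated resources should no longer need to be sequenced against them.

    func.func @sketch(%arg0: tensor<f32>, %arg1: tensor<f32>,
                      %key: tensor<!tf_type.string>,
                      %var: tensor<!tf_type.resource<tensor<f32>>>) {
      // Both runs write the shared _XlaRun resource, so they remain ordered.
      "tf._XlaRun"(%arg0, %key) : (tensor<f32>, tensor<!tf_type.string>) -> ()
      "tf._XlaRun"(%arg1, %key) : (tensor<f32>, tensor<!tf_type.string>) -> ()
      // A read of a variable that is not an operand of the runs can be treated
      // as independent of them.
      %value = "tf.ReadVariableOp"(%var) : (tensor<!tf_type.resource<tensor<f32>>>) -> tensor<f32>
      func.return
    }

The updated side-effect-analysis test below checks the ordering between consecutive `tf._XlaRun` ops.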
PiperOrigin-RevId: 550732191 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 2 +- .../compiler/mlir/tensorflow/ir/tf_op_base.td | 2 +- .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 27 ++++++++++++++++ .../mlir/tensorflow/ir/tf_side_effects.h | 4 +++ .../tests/side-effect-analysis-test.mlir | 32 +++++++++++++++++++ 5 files changed, 65 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 45ce31983de9d0..ab078e2ad31018 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -22775,7 +22775,7 @@ execution the transfer corresponds to.}]>:$dynamic_key, TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } -def TF__XlaRunOp : TF_Op<"_XlaRun", []> { +def TF__XlaRunOp : TF_Op<"_XlaRun", [DeclareOpInterfaceMethods]> { let summary = "XLA Run Op. For use by the XLA JIT only."; let description = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 7466b3bc37e487..246b51c172bdb6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -190,7 +190,7 @@ def TF_XlaHostComputeResource : TF_ResourceBase<"XlaHostCompute">; def TF_WriteTrainingPredictionsResource : TF_ResourceBase<"WriteTrainingPredictions">; def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">; def TF_NcclAllReduceOrderingResource : TF_ResourceBase<"NcclAllReduceOrdering">; - +def TF__XlaRunResource : TF_ResourceBase<"_XlaRun">; // Fake resource, see `TF_MustExecute` below. def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 1822e797c8089c..b8f2387660be46 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -2350,6 +2350,33 @@ void TPUExecuteOp::getEffects( } } +//===----------------------------------------------------------------------===// +// _XlaRunOp +//===----------------------------------------------------------------------===// + +void _XlaRunOp::getEffects( + SmallVectorImpl> + &effects) { + effects.reserve(2 * getArgs().size() + 1); + effects.emplace_back(MemoryEffects::Write::get(), + ResourceEffects::_XlaRun::get()); + + for (Value value : getArgs()) { + if (value.getType() + .cast() + .getElementType() + .isa()) { + // Conservatively mark resource handles as read and write, as without + // analyzing _XlaCompile, there is not sufficient information to determine + // effects on resources. 
+ effects.emplace_back(MemoryEffects::Read::get(), value, + ResourceEffects::Variable::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + ResourceEffects::Variable::get()); + } + } +} + //===----------------------------------------------------------------------===// // WriteTrainingPredictions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 2fe672f147716e..317bfb3a36a6d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -113,6 +113,10 @@ struct WriteTrainingPredictions StringRef getName() final { return "WriteTrainingPredictions"; } }; +struct _XlaRun : public ::mlir::SideEffects::Resource::Base<_XlaRun> { + StringRef getName() final { return "_XlaRun"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index 575b510cff8e31..6cd1cf67c0a93d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2845,3 +2845,35 @@ func.func @call_pure_function(%arg0: tensor) -> tensor, + %arg1: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 5}} + %island = tf_executor.island { + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + "tf._XlaRun"(%arg0, %arg0) : (tensor, tensor) -> () + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + "tf._XlaRun"(%arg1, %arg1) : (tensor, tensor) -> () + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {2}}} + tf_executor.yield + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {1}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + func.return + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} From 8192e47d186ec36620eb63f7d7781c07a3d67cfb Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 24 Jul 2023 19:49:59 -0700 Subject: [PATCH 086/410] Change `mhlo.is_same_data_across_replicas` from unit attr to bool attr Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO. JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type. 
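As a minimal illustration of the attribute change on a function argument (the function names and tensor types are placeholders; the authoritative updates are in the exporter and tests below):

    // Before: unit attribute, which VHLO cannot serialize.
    func.func @before(%arg0: tensor<f32> {mhlo.is_same_data_across_replicas}) -> tensor<f32> {
      func.return %arg0 : tensor<f32>
    }
    // After: explicit bool attribute, aligned with StableHLO.
    func.func @after(%arg0: tensor<f32> {mhlo.is_same_data_across_replicas = true}) -> tensor<f32> {
      func.return %arg0 : tensor<f32>
    }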
PiperOrigin-RevId: 550745955 --- .../tensorflow/tests/annotate-parameter-replication.mlir | 8 ++++---- .../tests/compile_mlir_util/constant-folding.mlir | 2 +- .../tf_to_hlo_pipeline/sccp-post-shape-inference.mlir | 2 +- .../compiler/mlir/tensorflow/tests/tpu_rewrite.mlir | 4 ++-- .../mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir | 8 ++++---- .../transforms/annotate_parameter_replication.cc | 3 ++- .../mlir/tensorflow/transforms/tpu_rewrite_pass.cc | 4 ++-- tensorflow/compiler/xla/python/xla_client.py | 2 +- .../compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 5 +++-- .../xla/translate/mhlo_to_hlo/tests/export_replicas.mlir | 2 +- 10 files changed, 21 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir index a5d45664b9221f..b9cec3bae4b76f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir @@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor, - // CHECK-SAME: %[[ARG1:.*]]: tensor {mhlo.is_same_data_across_replicas} + // CHECK-SAME: %[[ARG1:.*]]: tensor {mhlo.is_same_data_across_replicas = true} // CHECK-SAME: %[[ARG2:.*]]: tensor) func.func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} { } // CHECK-LABEL: func @_func - // CHECK-SAME: %[[ARG0:.*]]: tensor {mhlo.is_same_data_across_replicas}, + // CHECK-SAME: %[[ARG0:.*]]: tensor {mhlo.is_same_data_across_replicas = true}, // CHECK-SAME: %[[ARG1:.*]]: tensor, - // CHECK-SAME: %[[ARG2:.*]]: tensor>> {mhlo.is_same_data_across_replicas} + // CHECK-SAME: %[[ARG2:.*]]: tensor>> {mhlo.is_same_data_across_replicas = true} func.func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0 : tensor @@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { } // CHECK-LABEL: func @_func - // CHECK-NOT: mhlo.is_same_data_across_replicas + // CHECK-NOT: mhlo.is_same_data_across_replicas = true func.func @_func(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor func.return %0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir index 67deb8a0a1dcd7..0250bf43c52418 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/constant-folding.mlir @@ -1,7 +1,7 @@ // RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=10,19:19,10 -tf-xla-emit-use-tuple-args -tf-xla-emit-return-tuple | FileCheck %s module attributes {tf.versions = {producer = 179 : i32}} { - func.func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> { + func.func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas = true}) -> tensor<10x19xf32> { %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32> 
func.return %1 : tensor<10x19xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/sccp-post-shape-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/sccp-post-shape-inference.mlir index de0c3d80706696..ed393081108e1d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/sccp-post-shape-inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/sccp-post-shape-inference.mlir @@ -5,7 +5,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { // CHECK-LABEL: func @main - func.func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor { + func.func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas = true}) -> tensor { %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> %1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index e5ff1c4bc390fb..2075fb93267f5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -666,10 +666,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<8xi32>) -> tensor<8xi32> // CHECK: metadata // CHECK-SAME: is_same_data_across_replicas: true - // CHECK-SAME: mhlo.is_same_data_across_replicas + // CHECK-SAME: mhlo.is_same_data_across_replicas = true func.return %0: tensor<8xi32> } - func.func @tpu0_func(%arg0: tensor<8xi32> {mhlo.is_same_data_across_replicas}) -> tensor<8xi32> { + func.func @tpu0_func(%arg0: tensor<8xi32> {mhlo.is_same_data_across_replicas = true}) -> tensor<8xi32> { func.return %arg0 : tensor<8xi32> } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir index 3f7ba1524beb23..2125f2877fb8cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -113,8 +113,8 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSI func.return } // CHECK-LABEL: func private @_func - // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor 
{mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { - func.func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + func.func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = 
"\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor @@ -187,8 +187,8 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSI func.return } // CHECK-LABEL: func private @_func - // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { - func.func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor 
{mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { + func.func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 6fd14f88ad9a97..996686eb525d03 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -88,7 +88,8 @@ void AnnotateParameterReplicationPass::runOnOperation() { // Not a replication-invariant operand. continue; } - func.setArgAttr(entry.index(), kReplicationAttr, builder.getUnitAttr()); + func.setArgAttr(entry.index(), kReplicationAttr, + builder.getBoolAttr(true)); } }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 97a96797d021fb..c6e0185320268e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -264,9 +264,9 @@ LogicalResult SetMetadataProtoArgs( // Populate set_is_same_data_across_replicas // Note: this information is duplicated and can be removed from the proto // and here once MLIR bridge phase 2 doesn't fallback to the old bridge. - mlir::UnitAttr attr = op.getFuncOp().getArgAttrOfType( + auto attr = op.getFuncOp().getArgAttrOfType( index, replication_attr_name); - arg->set_is_same_data_across_replicas(attr != nullptr); + arg->set_is_same_data_across_replicas(attr != nullptr && attr.getValue()); // Currently only support first dimension to be bounded dynamic. 
arg->mutable_is_bounded_dynamic_dim()->Add( diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index af8950932876b0..90e1414154d5c5 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,7 +44,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 171 +_version = 172 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 56f42aa4ce801e..8db572653be221 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -3160,8 +3160,9 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { bool any_arg_replicated = false; entry_args_same_across_replicas.reserve(f.getNumArguments()); for (int64_t i = 0; i < f.getNumArguments(); ++i) { - auto attr = f.getArgAttrOfType(i, kReplicationAttr); - entry_args_same_across_replicas.push_back(attr != nullptr); + auto attr = f.getArgAttrOfType(i, kReplicationAttr); + entry_args_same_across_replicas.push_back(attr != nullptr && + attr.getValue()); any_arg_replicated |= entry_args_same_across_replicas.back(); // Pass the alias info to the builder so that it will build the alias info // into the resulting HloModule. diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir b/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir index 8650ecc333cb44..5b11ed4a4c323f 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir @@ -3,7 +3,7 @@ // Tests that the exported HLO module keeps parameter replication annotation. // CHECK: HloModule -func.func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> { +func.func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas = true}) -> tensor<16x16xf32> { %0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } From 0be6c1eed172414c3a951f8ec46d2742ddab7abd Mon Sep 17 00:00:00 2001 From: Dateng Lin Date: Mon, 24 Jul 2023 20:31:09 -0700 Subject: [PATCH 087/410] Avoided redundant map lookup. 
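In isolation, the pattern looks like the sketch below (illustrated with std::unordered_map and placeholder types rather than the llvm::DenseMap used by the expander): call find() once and reuse the returned iterator instead of doing a containment check followed by a second lookup of the same key.

#include <optional>
#include <string>
#include <unordered_map>

// Illustrative stand-in for the DTensor code; one find() replaces
// find() + lookup() on the same key.
std::optional<std::string> GetLayout(
    const std::unordered_map<int, std::string>& layouts, int key) {
  auto iter = layouts.find(key);  // single hash lookup
  if (iter == layouts.end()) return std::nullopt;
  return iter->second;            // reuse the iterator, no second lookup
}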
PiperOrigin-RevId: 550754162 --- .../dtensor/mlir/expansions/fft_spmd_expander.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc index 69d07dfb06a4e8..d8de0e07b3be25 100644 --- a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc @@ -427,10 +427,10 @@ StatusOr FFTSPMDExpander::ExpandOp(mlir::Operation* op) { StatusOr> FFTSPMDExpander::ComputeLayoutForward( mlir::Operation* op, const llvm::DenseMap& input_layouts) { - if (input_layouts.find(0) == input_layouts.end()) - return llvm::DenseMap(); + auto iter = input_layouts.find(0); + if (iter == input_layouts.end()) return llvm::DenseMap(); - const Layout& input_layout = input_layouts.lookup(0); + const Layout& input_layout = iter->second; std::vector sharding_specs = input_layout.sharding_spec_strs(); if (sharding_specs.empty()) return absl::FailedPreconditionError( @@ -468,10 +468,10 @@ StatusOr> FFTSPMDExpander::ComputeLayoutForward( StatusOr> FFTSPMDExpander::ComputeLayoutBackward( mlir::Operation* op, const llvm::DenseMap& output_layouts) { - if (output_layouts.find(0) == output_layouts.end()) - return llvm::DenseMap(); + auto iter = output_layouts.find(0); + if (iter == output_layouts.end()) return llvm::DenseMap(); - const Layout& output_layout = output_layouts.lookup(0); + const Layout& output_layout = iter->second; std::vector sharding_specs = output_layout.sharding_spec_strs(); if (sharding_specs.empty()) return absl::FailedPreconditionError( From 1e957846e0b5de00373894ef1a8e111d22dbff3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 20:41:47 -0700 Subject: [PATCH 088/410] Integrate LLVM at llvm/llvm-project@c6f66de21af0 Updates LLVM usage to match [c6f66de21af0](https://github.com/llvm/llvm-project/commit/c6f66de21af0) PiperOrigin-RevId: 550756438 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index b05b90ca851727..906573eae3531e 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "fc3b7874b6c95f04a249e2c9da3c5221f50c85b2" - LLVM_SHA256 = "4956bc88608ac68c9536ace1669a3936e9b40896e8fe1e2c1d8658986818bf6a" + LLVM_COMMIT = "c6f66de21af060ead6e5402858351e9e869dc15f" + LLVM_SHA256 = "91d75027562704b7e9941bbeff6b174ecfe5ea26be5dcd149b5131a2109520ea" tf_http_archive( name = name, From f15c449f0a79b5ec85066a6162388ff01fe1fb59 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 20:48:00 -0700 Subject: [PATCH 089/410] Removing tfrt gpu backend. 
PiperOrigin-RevId: 550757559 --- .../compiler/mlir/tfrt/lhlo-tfrt-opt.cc | 3 +- .../runtime/gpu/conversion_function.cc | 242 ------------------ .../runtime/gpu/conversion_function.h | 34 --- .../runtime/gpu/static_registration.cc | 32 --- tensorflow/core/runtime_fallback/util/BUILD | 35 --- 5 files changed, 1 insertion(+), 345 deletions(-) delete mode 100644 tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.cc delete mode 100644 tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h delete mode 100644 tensorflow/core/runtime_fallback/runtime/gpu/static_registration.cc diff --git a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc index 7af989f9cd3218..5c1d2cbbcebaf9 100644 --- a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime #include "tfrt/gpu/passes/passes.h" // from @tf_runtime #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime @@ -30,7 +29,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); registry.insert(); + mlir::mhlo::MhloDialect>(); tfrt::RegisterTFRTDialects(registry); mlir::registerAllPasses(); diff --git a/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.cc b/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.cc deleted file mode 100644 index f5f08d0f8771ce..00000000000000 --- a/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements conversion function between TFRuntimeFallback and Gpu -// tensors. 
- -#include "tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h" - -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" -#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h" -#include "tensorflow/core/runtime_fallback/util/attr_util.h" -#include "tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h" -#include "tensorflow/core/runtime_fallback/util/type_util.h" -#include "tfrt/gpu/device/conversion_function.h" // from @tf_runtime -#include "tfrt/gpu/device/device.h" // from @tf_runtime -#include "tfrt/gpu/device/device_util.h" // from @tf_runtime -#include "tfrt/gpu/gpu_types.h" // from @tf_runtime -#include "tfrt/gpu/tensor/dense_gpu_tensor.h" // from @tf_runtime -#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime -#include "tfrt/host_context/diagnostic.h" // from @tf_runtime -#include "tfrt/host_context/execution_context.h" // from @tf_runtime -#include "tfrt/host_context/host_buffer.h" // from @tf_runtime -#include "tfrt/host_context/host_context.h" // from @tf_runtime -#include "tfrt/support/error_util.h" // from @tf_runtime -#include "tfrt/tensor/conversion_registry.h" // from @tf_runtime -#include "tfrt/tensor/conversion_utils.h" // from @tf_runtime -#include "tfrt/tensor/tensor.h" // from @tf_runtime - -namespace tensorflow { -namespace tfd { - -static tfrt::Expected -CopyRefGpuTensorToRuntimeFallbackTensor( - const tfrt::gpu::DenseGpuTensor& gpu_tensor, Device* device, - Device* op_device, EagerContext* eager_ctx) { - // Do not copy the gpu buffer content, CopyRef on the buffer instead. - tfrt::AsyncValueRef gpu_buffer = - gpu_tensor.CopyBufferRef(); - tfrt::Expected tensor = MoveGpuBufferToTFTensor( - std::move(gpu_buffer), gpu_tensor.dtype(), gpu_tensor.shape()); - if (!tensor) return tensor.takeError(); - - OwnedTensorHandle tensor_handle{tensorflow::TensorHandle::CreateLocalHandle( - std::move(tensor.get()), device, op_device, eager_ctx)}; - return RuntimeFallbackTensor(gpu_tensor.shape(), gpu_tensor.dtype(), - std::move(tensor_handle)); -} - -// Convert the RuntimeFallbackTensor to a GpuTensor (currently DenseGpuTensor -// only). If the source tensor is on CPU, copy the data to GPU. If the source -// tensor is already on GPU, just do type conversion. -// TODO(b/167254525): For TFRuntimeFallback tensor, create separate tensor -// types for different devices. -static tfrt::AsyncValueRef -ConvertRuntimeFallbackTensorToDenseGpuTensor( - const RuntimeFallbackTensor& tensor, const tfrt::Device& src, - const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) { - auto* host_ctx = exec_ctx.host(); - - auto tf_tensor_handle = tensor.GetTensorHandle(); - - tensorflow::Status status; - const char* device_name = tf_tensor_handle->DeviceName(&status); - - auto tensor_device_ref = - host_ctx->GetDeviceManager()->GetDeviceRef(device_name); - - if (!tensor_device_ref) { - tensor_device_ref = - host_ctx->GetDeviceManager()->GetDeviceRef( - ConvertTfDeviceNameToTfrtDefault(device_name)); - } - - if (!tensor_device_ref) - return tfrt::EmitErrorAsync( - exec_ctx, - tfrt::StrCat("Failed to find a device with name: ", device_name)); - - if (!status.ok()) { - return EmitErrorAsync( - exec_ctx, tfrt::StrCat("error getting device name from TensorHandle: ", - status.message())); - } - - // Check if the underlying tensorflow::TensorHandle is already on GPU. - // If so, just convert the RuntimeFallbackTensor to GpuTensor. 
- if (tensor_device_ref.get() == &dst) { - tensorflow::TensorShape shape; - tensorflow::Status status = tf_tensor_handle->Shape(&shape); - if (!status.ok()) { - return EmitErrorAsync( - exec_ctx, tfrt::StrCat("error getting shape from TF tensor handle: ", - status.message())); - } - - auto tf_shape = shape.dim_sizes(); - DataType dtype = tf_tensor_handle->DataType(); - // Note that GPU tensor might not be available yet. But since TF - // and TFRT share the same stream, this is ok. - const tensorflow::Tensor* tf_tensor = nullptr; - status = tf_tensor_handle->Tensor(&tf_tensor); - if (!status.ok()) { - return EmitErrorAsync(exec_ctx, - tfrt::StrCat("error calling TensorHandle::Tensor: ", - status.message())); - } - - auto platform = tensorflow::tfd::GetTfrtGpuPlatform(tf_tensor_handle); - - void* data = tf_tensor->data(); - size_t size = tf_tensor->TotalBytes(); - - // Need to add a reference here since we are transferring the ownership - // of the Tensorflow::TensorHandle and the underlying GPU buffer to - // tfrt::DenseGpuTensor. Otherwise, the TensorHandle will be released - // when he RuntimeFallbackTensor goes out of scope after the tensor - // conversion. The GPU buffer will be deleted as well. - tf_tensor_handle->Ref(); - OwnedTensorHandle owned_tf_tensor_handle = - OwnedTensorHandle{TensorHandleFromInterface(tf_tensor_handle)}; - - // The OwnedTensorHandle holds a reference on underlying Tensorflow buffer - // and is held alive by GpuOneShotAllocator. - auto allocator = tfrt::MakeAvailableAsyncValueRef< - tfrt::gpu::GpuOneShotAllocator>( - tfrt::gpu::wrapper::Pointer(data, platform), - std::move(owned_tf_tensor_handle)); - llvm::Expected gpu_buffer = - tfrt::gpu::GpuBuffer::Allocate(std::move(allocator), size); - if (!gpu_buffer) { - return tfrt::MakeErrorAsyncValueRef(tfrt::StrCat(gpu_buffer.takeError())); - } - - // create DenseGpuTensor. - tfrt::gpu::DenseGpuTensor gpu_tensor{ - tfrt::TensorShape( - std::vector(tf_shape.begin(), tf_shape.end())), - GetTfrtDtype(dtype), - tfrt::MakeAvailableAsyncValueRef( - std::move(*gpu_buffer))}; - - return tfrt::MakeAvailableAsyncValueRef( - std::move(gpu_tensor)); - } else { - // TODO(chuanhao): clean up the branch after cl/325503773. Currently this - // branch is needed since we don't know what type of tensor that - // RuntimeFallbackTensor holds. - // tensorflow::TensorHandle is on host CPU. - assert(tensor_device_ref.get() == &host_ctx->GetHostDevice()); - - // Convert the TFRuntimeFallbackTensor to DenseHostTensor. 
- auto host_tensor_ref = tfrt::ConvertTensor( - exec_ctx, tensor, src, src, tfrt::DenseHostTensor::kTensorType); - - if (!host_tensor_ref.get().IsTensorType(tfrt::DenseHostTensor::kTensorType)) - return EmitErrorAsync(exec_ctx, - "TFRuntimeFallbackTensor not converted to " - "DenseHostTensor."); - llvm::Expected current_context = - dst.SetCurrentContext(); - if (!current_context) { - return tfrt::MakeErrorAsyncValueRef( - tfrt::StrCat(current_context.takeError())); - } - - auto expected_gpu_tensor = - tfrt::gpu::ConvertDenseHostTensorToDenseGpuTensor( - std::move(current_context.get()), dst.stream(), dst.allocator(), - llvm::cast(host_tensor_ref.get()), host_ctx); - if (!expected_gpu_tensor) { - return EmitErrorAsync( - exec_ctx, - absl::InternalError(toString(expected_gpu_tensor.takeError()))); - } - return tfrt::MakeAvailableAsyncValueRef( - std::move(expected_gpu_tensor.get())); - } -} - -static tfrt::AsyncValueRef -ConvertDenseGpuTensorToRuntimeFallbackTensor( - const tfrt::gpu::DenseGpuTensor& tensor, const tfrt::gpu::GpuDevice& src, - const tfrt::gpu::GpuDevice& dst, const tfrt::ExecutionContext& exec_ctx) { - tfrt::ResourceContext* resource_context = exec_ctx.resource_context(); - tensorflow::tfd::EagerContextResource* eager_context_resource = - resource_context - ->GetOrCreateResource( - tensorflow::tfd::kEagerContextResourceName); - - tfrt::Expected eager_ctx_expected = - eager_context_resource->GetTFEagerContext(); - if (!eager_ctx_expected) - return EmitErrorAsync(exec_ctx, eager_ctx_expected.takeError()); - - EagerContext* eager_ctx = eager_ctx_expected.get(); - - assert(&src == &dst); - Device* device; - Status status = eager_ctx->local_device_mgr()->LookupDevice( - ToAbslStringView(dst.name()), &device); - if (!status.ok()) - return EmitErrorAsync(exec_ctx, - absl::InternalError(tfrt::StrCat( - "error looking up gpu device from EagerContext: ", - status.message()))); - - auto fallback_tensor = CopyRefGpuTensorToRuntimeFallbackTensor( - tensor, device, device, eager_ctx); - if (fallback_tensor) { - return tfrt::MakeAvailableAsyncValueRef( - std::move(*fallback_tensor)); - } else { - return EmitErrorAsync( - exec_ctx, absl::InternalError(toString(fallback_tensor.takeError()))); - } -} - -void RegisterTFRuntimeFallbackTensorToGpuConversionFn( - tfrt::TensorConversionFnRegistry* registry) { - registry->AddTensorConversionFn( - TFRT_CONVERSION(ConvertRuntimeFallbackTensorToDenseGpuTensor)); - - registry->AddTensorConversionFn( - TFRT_CONVERSION(ConvertDenseGpuTensorToRuntimeFallbackTensor)); -} - -} // namespace tfd -} // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h b/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h deleted file mode 100644 index 105902fee11836..00000000000000 --- a/tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// This file declares TFRuntimeFallback tensor conversion functions for copying -// between gpu and host. - -#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_GPU_CONVERSION_FUNCTION_H_ -#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_GPU_CONVERSION_FUNCTION_H_ - -#include "tfrt/tensor/conversion_registry.h" // from @tf_runtime - -namespace tensorflow { -namespace tfd { - -// Register conversion functions for TFRuntimeFallbackTensors. -void RegisterTFRuntimeFallbackTensorToGpuConversionFn( - tfrt::TensorConversionFnRegistry* registry); - -} // namespace tfd -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_GPU_CONVERSION_FUNCTION_H_ diff --git a/tensorflow/core/runtime_fallback/runtime/gpu/static_registration.cc b/tensorflow/core/runtime_fallback/runtime/gpu/static_registration.cc deleted file mode 100644 index eebdf225f60a17..00000000000000 --- a/tensorflow/core/runtime_fallback/runtime/gpu/static_registration.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file uses a static constructor to automatically register conversion -// functions for TFRuntimeFallback tensor. - -#include "tensorflow/core/runtime_fallback/runtime/gpu/conversion_function.h" -#include "tfrt/tensor/conversion_registry.h" // from @tf_runtime - -namespace tensorflow { -namespace tfd { - -static bool runtime_fallback_to_gpu_conversion_fn_registration = []() { - tfrt::AddStaticTensorConversionFn( - RegisterTFRuntimeFallbackTensorToGpuConversionFn); - return true; -}(); - -} // namespace tfd -} // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/util/BUILD b/tensorflow/core/runtime_fallback/util/BUILD index 99cd09e1e91388..53e8584c7ae754 100644 --- a/tensorflow/core/runtime_fallback/util/BUILD +++ b/tensorflow/core/runtime_fallback/util/BUILD @@ -115,41 +115,6 @@ tf_cuda_library( }), ) -tf_cuda_library( - name = "gpu_util", - srcs = [ - "gpu/gpu_utils.cc", - ], - hdrs = [ - "gpu/gpu_utils.h", - ], - compatible_with = [], - # Only build this library with --config=cuda. 
- tags = [ - "manual", - "requires_cuda", - ], - deps = [ - ":tensor_util", - ":type_util", - "@tf_runtime//:support", - "@tf_runtime//backends/gpu:gpu_config", - "@tf_runtime//backends/gpu:gpu_device", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs - ], - "//conditions:default": [ - "//tensorflow/c:tf_tensor", - "//tensorflow/c:tf_tensor_internal", - "//tensorflow/compiler/xla/stream_executor:platform", - "//tensorflow/compiler/xla/stream_executor/cuda:cuda_driver", - "//tensorflow/core/common_runtime/eager:tensor_handle", - "//tensorflow/core/common_runtime/gpu:gpu_runtime", - ], - }), -) - tf_cc_test( name = "type_util_test", srcs = ["type_util_test.cc"], From 996865f038df52acefa4dcba97047fe3181e1280 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Mon, 24 Jul 2023 20:56:51 -0700 Subject: [PATCH 090/410] Create an async_while kernel that may dispatch an iteration before the previous iteration is complete. PiperOrigin-RevId: 550759127 --- tensorflow/core/tfrt/mlrt/kernel/kernel.cc | 294 +++++++++++++++++- .../core/tfrt/mlrt/kernel/kernel_test.cc | 281 +++++++++++++++++ 2 files changed, 574 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc index 0c2e0bf12a67c8..b47be8d5aa6520 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc @@ -50,6 +50,297 @@ namespace tensorflow { namespace tf_mlrt { namespace { +// AsyncWhileOp dispatch the body function repeatedly until the body function +// returns a predicate value of false. Each invocation of the body function +// corresponds to an iteration in a while loop. The body function is expected to +// have the following input signature (predicate_promise, +// mutable_tensor0_future, mutable_tensor0_promise, mutable_tensor1_future, +// mutable_tensor1_promise, ...., immutable_tensors). AsyncWhileOp dispatch the +// next iteraion as soon as the previous iteration has set the +// predicate_promise. Hence, in the case that the body function set +// predicate_promise earlier than setting any other promises, multiple +// iterations can run parallelly via this op. +class AsyncWhileOp : mlrt::KernelFrame { + public: + using KernelFrame::KernelFrame; + + static constexpr char kName[] = "tf_mlrt.async_while"; + + mlrt::bc::Function body_function() const { + uint32_t func_idx = attributes().GetAs(0); + return execution_context() + .loaded_executable() + .executable() + .functions()[func_idx]; + } + + // Arguments that remains unchanged between iterations are called + // immutable(invariants). Immutables are all at the bottom of the argument + // list. Immutable_size reflects the number of immutables. + uint32_t immutable_size() const { return attributes().GetAs(1); } + + void Invoke(); + + private: + // This utility function is used when an iteration has set its + // predicate_promise. If predicate is true, it dispatches the next iteration. + // If predicate is false, it set ups the AsyncOp's return futures via + // final_promises. + static void OnPredicateReady( + tensorflow::tfrt_stub::FallbackTensor predicate, + std::vector async_handles, + std::vector mutable_tensor_futures, + std::vector immutable_tensors, + std::vector final_promises, mlrt::bc::Function body_fn, + mlrt::ExecutionContext& execution_context, uint32_t counter); + + // A utility function to populate the results in final_promises. 
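+  // The first final promise carries the final predicate (kept as a tensor so
+  // the caller can consume every result uniformly via await_all); it is
+  // followed by the mutable tensors and then the unchanged immutable tensors.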
+ static void PopulateFinalPromise( + std::vector& final_promises, + const std::vector& mutable_tensor_futures, + const std::vector& + immutable_tensors); +}; + +void AsyncWhileOp::OnPredicateReady( + tensorflow::tfrt_stub::FallbackTensor predicate, + std::vector async_handles, + std::vector mutable_tensor_futures, + std::vector immutable_tensors, + std::vector final_promises, mlrt::bc::Function body_fn, + mlrt::ExecutionContext& execution_context, uint32_t counter) { + // final_promises[0] contains the final predicate and serves something similar + // as async_handle that the caller can wait and know the program is complete. + DCHECK_EQ(final_promises.size(), + mutable_tensor_futures.size() + immutable_tensors.size() + 1); + + // [predicate_promise; arg0_future, arg0_promise, arg1_future, arg1_promise, + // ..., immutable_args] + const uint32_t body_argument_size = + 1 + 2 * mutable_tensor_futures.size() + immutable_tensors.size(); + DCHECK_EQ(body_fn.input_regs().size(), body_argument_size); + + tsl::profiler::TraceMe trace_me([&]() { + return tsl::profiler::TraceMeEncode( + "tf_mlrt.AsyncWhileOp.OnPredicateReady", + {{"counter", counter}, {"name", body_fn.name().Get()}}); + }); + + bool predicate_value = predicate.tensor().scalar()(); + if (!predicate_value) { + // No more iterations. + if (async_handles.empty()) { + // Initial predicate is false + PopulateFinalPromise(final_promises, mutable_tensor_futures, + immutable_tensors); + } else { + // Iterations ends. Wait for all futures to be ready. + mlrt::Future await_all = mlrt::AwaitAll(absl::MakeSpan(async_handles)); + std::move(await_all).Then( + [final_promises = std::move(final_promises), + variant_tensor_futures = std::move(mutable_tensor_futures), + async_handles = std::move(async_handles), + immutable_tensors](absl::Status status) mutable { + if (status.ok()) { + PopulateFinalPromise(final_promises, variant_tensor_futures, + immutable_tensors); + return; + } else { + for (auto& final_promise : final_promises) { + std::move(final_promise).SetError(status); + } + } + }); + } + return; + } + // proceed to schedule the next iteration n+1. + // Creates arguments for dispatching the next iteration. + std::vector body_args; + body_args.resize(body_argument_size); + + // Set predicate_promise + auto arg_iter = body_args.begin(); + auto predicate_promise = + mlrt::Promise::Allocate(); + auto predicate_future = predicate_promise.GetFuture(); + arg_iter->Set(std::move(predicate_promise)); + ++arg_iter; + + // Current iteration n receives mutable tensor values in future from + // iteration n-1 and creates promises to return those mutable tensors after + // updating them from the current iteration. + std::vector next_futures; + next_futures.reserve(mutable_tensor_futures.size()); + + for (auto& mutable_tensor : mutable_tensor_futures) { + // Future from the previous iteration as input to the current iteration. + arg_iter->Set(std::move(mutable_tensor)); + ++arg_iter; + + // Promise to return values from the current iteration. + auto next_promise = + mlrt::Promise::Allocate(); + next_futures.push_back(next_promise.GetFuture()); + arg_iter->Set(std::move(next_promise)); + ++arg_iter; + } + + // Tensors that remains unchanged across iterations are copied over due to + // asynchronous execution between iterations. + for (auto& immutable_tensor : immutable_tensors) { + arg_iter->Set(immutable_tensor); + arg_iter++; + } + + // Launch this iteration. 
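+  // Each iteration runs in its own AsyncHandle/ExecutionContext; the handle's
+  // promise is finished from the exit handler with the iteration's status, so
+  // awaiting the collected handles observes completion and errors of every
+  // dispatched iteration.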
+ auto [promise, handle] = mlrt::AsyncHandle::Allocate(execution_context); + auto& thread_execution_context = handle.execution_context(); + thread_execution_context.set_exit_handler( + [&execution_context = thread_execution_context, + promise = std::move(promise)]() mutable { + std::move(promise).Finish(execution_context.status()); + }); + + thread_execution_context.CallByMove(body_fn, absl::MakeSpan(body_args), + absl::Span()); + + thread_execution_context.work_queue()->AddTask( + [&execution_context = thread_execution_context]() { + mlrt::Execute(execution_context); + }); + + // save handles + async_handles.push_back(std::move(handle)); + + std::move(predicate_future) + .Then([futures = std::move(next_futures), + immutable_tensors = std::move(immutable_tensors), + final_promises = std::move(final_promises), + body_args = std::move(body_args), + async_handles = std::move(async_handles), body_fn, counter, + &execution_context = thread_execution_context]( + absl::StatusOr + predicate_result) mutable { + if (!predicate_result.ok()) { + auto status = predicate_result.status(); + mlrt::Future await_all = + mlrt::AwaitAll(absl::MakeSpan(async_handles)); + std::move(await_all).Then([final_promises = std::move(final_promises), + async_handles = std::move(async_handles), + status]() mutable { + for (auto& final_promise : final_promises) { + std::move(final_promise).SetError(status); + } + }); + execution_context.Fail(status); + return; + } + + // Keep body_args alive for thread execution. + OnPredicateReady(*predicate_result, std::move(async_handles), + std::move(futures), immutable_tensors, + std::move(final_promises), body_fn, execution_context, + ++counter); + }); +} + +void AsyncWhileOp::PopulateFinalPromise( + std::vector& final_promises, + const std::vector& mutable_tensor_futures, + const std::vector& + immutable_tensors) { + // The final predicate needs to be a tensor, not bool so that await_all + // can be used. + tensorflow::Tensor final_predicate_tensor(false); + + auto final_promise_iter = final_promises.begin(); + std::move(*final_promise_iter) + .Set( + tensorflow::tfrt_stub::FallbackTensor( + std::move(final_predicate_tensor))); + final_promise_iter++; + for (auto& mutable_tensor_future : mutable_tensor_futures) { + DCHECK(mutable_tensor_future.IsReady()); + std::move(*final_promise_iter) + .Set( + std::move(mutable_tensor_future + .Get())); + final_promise_iter++; + } + for (auto& immutable_tensor : immutable_tensors) { + std::move(*final_promise_iter) + .Set(immutable_tensor); + final_promise_iter++; + } +} + +void AsyncWhileOp::Invoke() { + mlrt::bc::Function body_fn = body_function(); + + // Argument: [final_predicate, %variant0, %variant1, ..., %invariant0,...] + // + // Results: [final_predicate, %variant0, %variant1, ..., %invariant0,...] + // + DCHECK_EQ(arguments().size(), results().size()); + + // [predicate_promise; arg0_future, arg0_promise, arg1_future, arg1_promise, + // ..., invariant_args] + // minus 1 b/c predicate is not a tensor + const uint32_t immutable_tensor_size = immutable_size(); + const uint32_t mutable_tensor_size = + arguments().size() - immutable_tensor_size - 1; + + const uint32_t body_argument_size = + 1 + (2 * mutable_tensor_size) + immutable_tensor_size; + DCHECK_EQ(body_fn.input_regs().size(), body_argument_size); + DCHECK_EQ(body_fn.output_regs().size(), 0); + + tsl::profiler::TraceMe trace_me([&]() { + return tsl::profiler::TraceMeEncode("tf_mlrt.async_while", + {{"name", body_fn.name().Get()}}); + }); + + // Save the future of final results. 
The last iteration will set the promises. + std::vector final_promises; + final_promises.reserve(arguments().size()); + for (int i = 0; i < arguments().size(); ++i) { + final_promises.push_back( + mlrt::Promise::Allocate()); + results()[i] = final_promises.back().GetFuture(); + } + + // Populate input arguments into a list of dummy futures to bootstrap the + // first iteration. + std::vector mutable_tensor_futures; + mutable_tensor_futures.reserve(mutable_tensor_size); + + // Plus 1 because the very first argument is a boolean predicate . + auto arg_iter = arguments().begin() + 1; + for (int i = 0; i < mutable_tensor_size; ++i) { + auto tensor_promise = + mlrt::Promise::Allocate(); + mutable_tensor_futures.push_back(tensor_promise.GetFuture()); + std::move(tensor_promise) + .Set( + arg_iter->Get()); + arg_iter++; + } + + std::vector immutable_tensors; + immutable_tensors.reserve(immutable_tensor_size); + for (int i = 0; i < immutable_tensor_size; ++i) { + immutable_tensors.push_back( + arg_iter->Get()); + arg_iter++; + } + OnPredicateReady(arguments()[0].Get(), + /*async_handles=*/{}, std::move(mutable_tensor_futures), + immutable_tensors, std::move(final_promises), body_fn, + execution_context(), + /*counter=*/0); +} + struct MapFnOp : mlrt::KernelFrame { using KernelFrame::KernelFrame; @@ -142,7 +433,7 @@ void MapFnOp::Invoke() { body_arg_last_uses.begin() + 2 * num_tensor_list_or_flow_in() + 2, true); - // Copy the invairant arguments (after max_iteration + + // Copy the invariant arguments (after max_iteration + // tensor_list_or_flow_ins) auto arg_iter = body_args.begin() + 2 * num_tensor_list_or_flow_in() + 2; for (int j = num_tensor_list_or_flow_in() + 1; j < arguments().size(); @@ -684,6 +975,7 @@ void RegisterTfMlrtKernels(mlrt::KernelRegistry& registry) { registry.Register(); registry.Register("tfrt_fallback_sync.executeop"); registry.Register(); + registry.Register(); registry.Register(); registry.Register(); registry.Register("tf_mlrt.set_resource", &SetResource); diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc index 3ec567743fdf21..6e090a7a49f257 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc @@ -2438,6 +2438,287 @@ TEST(KernelTest, PromiseReturn) { output.Get().tensor(), expected); } +// A function body for AsyncWhile. 
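+// It busy-waits until the previous loop count future is ready, then fulfills
+// the predicate promise (count + 1 < max_iterations) and the next loop count
+// promise with the incremented count.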
+void TestAsyncWhileFnBody(mlrt::KernelFrame frame) { + ASSERT_EQ(frame.arguments().size(), 4); + + auto predicate_promise = std::move(frame.arguments()[0].Get()); + auto prev_loop_count_future = frame.arguments()[1].Get(); + auto next_loop_count_promise = + std::move(frame.arguments()[2].Get()); + + int32_t max_iteration = frame.arguments()[3] + .Get() + .tensor() + .scalar()(); + + for (; !prev_loop_count_future.IsReady();) { + // wait for future to be ready + } + int32_t prev_loop_count = + prev_loop_count_future.Get() + .tensor() + .scalar()(); + tensorflow::Tensor next_loop_count(DT_INT32, {}); + next_loop_count.scalar()() = prev_loop_count + 1; + + tensorflow::Tensor predicate(DT_BOOL, {}); + predicate.scalar()() = prev_loop_count + 1 < max_iteration; + std::move(predicate_promise) + .Set(std::move(predicate)); + + std::move(next_loop_count_promise) + .Set(std::move(next_loop_count)); +} + +mlrt::bc::Buffer CreateAsyncWhileExecutable() { + mlrt::bc::Buffer buffer; + mlrt::bc::Allocator allocator(&buffer); + + auto executable_ctor = mlrt::bc::New(&allocator); + mlrt::testing::SymbolTable kernels; + std::vector kernel_names = {"tf_mlrt.async_while", + "tf_mlrt.await_all", + "test_async_while_body", "return"}; + executable_ctor.construct_kernel_names(kernel_names.size()) + .Assign(kernel_names); + kernels.Def(kernel_names); + mlrt::testing::AttributeTable attributes( + executable_ctor.construct_attributes(1)); + + attributes.Add("body_idx", 1); + attributes.Add("invariant_size", 1); + auto functions_ctor = executable_ctor.construct_functions(2); + + { + auto function_ctor = functions_ctor.ConstructAt(0); + function_ctor.construct_name("main"); + mlrt::testing::SymbolTable regs; + function_ctor.construct_input_regs(3).Assign( + regs.Def({"initial_predicate", "loop_count", "max_iterations"})); + + auto kernels_ctor = function_ctor.construct_kernels(3); + { + auto kernel_ctor = kernels_ctor.ConstructAt(0); + kernel_ctor.set_code(kernels.Use("tf_mlrt.async_while")); + kernel_ctor.construct_attributes(2).Assign( + {attributes.GetHandle("body_idx"), + attributes.GetHandle("invariant_size")}); + kernel_ctor.construct_arguments(3).Assign( + regs.Use({"initial_predicate", "loop_count", "max_iterations"})); + kernel_ctor.construct_results(3).Assign( + regs.Def({"last_predicate_future", "final_loop_count_future", + "final_max_iterations_future"})); + } + { + auto kernel_ctor = kernels_ctor.ConstructAt(1); + kernel_ctor.set_code(kernels.Use("tf_mlrt.await_all")); + kernel_ctor.construct_arguments(3).Assign( + regs.Use({"last_predicate_future", "final_loop_count_future", + "final_max_iterations_future"})); + kernel_ctor.construct_last_uses(3).Assign({true, true, true}); + kernel_ctor.construct_results(3).Assign(regs.Def( + {"last_predicate", "final_loop_count", "final_max_iterations"})); + } + { + auto kernel_ctor = kernels_ctor.ConstructAt(2); + kernel_ctor.set_code(kernels.Use("return")); + kernel_ctor.construct_arguments(1).Assign({regs.Use("final_loop_count")}); + } + function_ctor.set_num_regs(regs.size()); + function_ctor.construct_output_regs(1).Assign( + {regs.Use("final_loop_count")}); + } + { + auto function_ctor = functions_ctor.ConstructAt(1); + function_ctor.construct_name("body_function"); + + mlrt::testing::SymbolTable regs; + + function_ctor.construct_input_regs(4).Assign( + regs.Def({"predicate_promise", "prev_loop_count_future", + "loop_count_promise", "max_iterations"})); + auto kernels_ctor = function_ctor.construct_kernels(2); + { + auto kernel_ctor = 
kernels_ctor.ConstructAt(0); + kernel_ctor.set_code(kernels.Use("test_async_while_body")); + kernel_ctor.construct_arguments(4).Assign( + regs.Use({"predicate_promise", "prev_loop_count_future", + "loop_count_promise", "max_iterations"})); + } + { + auto kernel_ctor = kernels_ctor.ConstructAt(1); + kernel_ctor.set_code(kernels.Use("return")); + } + function_ctor.set_num_regs(regs.size()); + } + return buffer; +} + +struct AsyncWhileOpTestParams { + bool initial_predicate; + int final_result; +}; +class AsyncWhileOpTestFixture + : public ::testing::TestWithParam {}; +TEST_P(AsyncWhileOpTestFixture, AsyncWhileOp) { + auto params = GetParam(); + auto buffer = CreateAsyncWhileExecutable(); + + mlrt::bc::Executable executable(buffer.data()); + + mlrt::KernelRegistry registry; + RegisterTfMlrtKernels(registry); + registry.Register("test_async_while_body", TestAsyncWhileFnBody); + + mlrt::LoadedExecutable loaded_executable(executable, registry); + + auto work_queue = tfrt::CreateMultiThreadedWorkQueue( + /*num_threads=*/4, /*num_blocking_threads=*/4); + mlrt::ExecutionContext execution_context(&loaded_executable); + execution_context.set_work_queue(work_queue.get()); + + tensorflow::SessionOptions session_options; + tensorflow::FunctionDefLibrary fdef_lib; + TF_ASSERT_OK_AND_ASSIGN(auto fallback_state, tfrt_stub::FallbackState::Create( + session_options, fdef_lib)); + + std::function)> runner = + [](const std::function& f) { f(); }; + tfrt_stub::OpKernelRunnerTable runner_table; + tfd::FallbackResourceArray resource_array; + tfd::KernelFallbackCompatRequestState fallback_request_state( + &runner, &fallback_state->device_manager(), /*step_id=*/0, &runner_table, + &resource_array, /*user_intra_op_threadpool=*/nullptr, + /*model_metadata=*/std::nullopt, + &fallback_state->process_function_library_runtime()); + + tfrt::ResourceContext resource_context; + + auto tf_context = + std::make_unique(&fallback_request_state, &resource_context); + execution_context.AddUserContext(std::move(tf_context)); + + std::vector args; + args.resize(3); + + // initial predicate is true + tensorflow::Tensor initial_predicate_tensor{DT_BOOL, {}}; + initial_predicate_tensor.scalar()() = params.initial_predicate; + args.at(0).Set( + tfrt_stub::FallbackTensor(std::move(initial_predicate_tensor))); + + tensorflow::Tensor loop_count_tensor{DT_INT32, {}}; + loop_count_tensor.scalar()() = 0; + args.at(1).Set(tfrt_stub::FallbackTensor(std::move(loop_count_tensor))); + + tensorflow::Tensor max_iteration_tensor{DT_INT32, {}}; + max_iteration_tensor.scalar()() = 2; + args.at(2).Set(tfrt_stub::FallbackTensor(std::move(max_iteration_tensor))); + + mlrt::Value result; + + absl::Notification notification; + execution_context.set_exit_handler( + [¬ification]() { notification.Notify(); }); + + std::vector last_uses = {true, true, true}; + execution_context.Call(executable.functions()[0], last_uses, + absl::MakeSpan(args), absl::MakeSpan(&result, 1)); + mlrt::Execute(execution_context); + + notification.WaitForNotification(); + + ASSERT_OK(execution_context.status()); + + tensorflow::Tensor expected(tensorflow::DT_INT32, {}); + expected.scalar()() = params.final_result; + + auto& to_be = result.Get(); + tensorflow::test::ExpectEqual(to_be.tensor(), expected); +} + +INSTANTIATE_TEST_SUITE_P( + AsyncWhileOpTestSuite, AsyncWhileOpTestFixture, + ::testing::ValuesIn({{true, 2}, {false, 0}})); + +// A AsyncWhile body function that triggers failure. 
+void TestAsyncWhileFnBodyError(mlrt::KernelFrame frame) {
+  ASSERT_EQ(frame.arguments().size(), 4);
+
+  frame.execution_context().Fail(absl::InternalError("Test error"));
+}
+TEST(KernelTest, AsyncWhileOpError) {
+  auto buffer = CreateAsyncWhileExecutable();
+
+  mlrt::bc::Executable executable(buffer.data());
+
+  mlrt::KernelRegistry registry;
+  RegisterTfMlrtKernels(registry);
+  registry.Register("test_async_while_body", TestAsyncWhileFnBodyError);
+
+  mlrt::LoadedExecutable loaded_executable(executable, registry);
+
+  auto work_queue = tfrt::CreateMultiThreadedWorkQueue(
+      /*num_threads=*/4, /*num_blocking_threads=*/4);
+  mlrt::ExecutionContext execution_context(&loaded_executable);
+  execution_context.set_work_queue(work_queue.get());
+
+  tensorflow::SessionOptions session_options;
+  tensorflow::FunctionDefLibrary fdef_lib;
+  TF_ASSERT_OK_AND_ASSIGN(auto fallback_state, tfrt_stub::FallbackState::Create(
+                                                   session_options, fdef_lib));
+
+  std::function<void(std::function<void()>)> runner =
+      [](const std::function<void()>& f) { f(); };
+  tfrt_stub::OpKernelRunnerTable runner_table;
+  tfd::FallbackResourceArray resource_array;
+  tfd::KernelFallbackCompatRequestState fallback_request_state(
+      &runner, &fallback_state->device_manager(), /*step_id=*/0, &runner_table,
+      &resource_array, /*user_intra_op_threadpool=*/nullptr,
+      /*model_metadata=*/std::nullopt,
+      &fallback_state->process_function_library_runtime());
+
+  tfrt::ResourceContext resource_context;
+
+  auto tf_context =
+      std::make_unique<Context>(&fallback_request_state, &resource_context);
+  execution_context.AddUserContext(std::move(tf_context));
+
+  std::vector<mlrt::Value> args;
+  args.resize(3);
+
+  // initial predicate is true
+  tensorflow::Tensor initial_predicate_tensor{DT_BOOL, {}};
+  initial_predicate_tensor.scalar<bool>()() = true;
+  args.at(0).Set(
+      tfrt_stub::FallbackTensor(std::move(initial_predicate_tensor)));
+
+  tensorflow::Tensor loop_count_tensor{DT_INT32, {}};
+  loop_count_tensor.scalar<int32_t>()() = 0;
+  args.at(1).Set(tfrt_stub::FallbackTensor(std::move(loop_count_tensor)));
+
+  tensorflow::Tensor max_iteration_tensor{DT_INT32, {}};
+  max_iteration_tensor.scalar<int32_t>()() = 2;
+  args.at(2).Set(tfrt_stub::FallbackTensor(std::move(max_iteration_tensor)));
+
+  mlrt::Value result;
+
+  absl::Notification notification;
+  execution_context.set_exit_handler(
+      [&notification]() { notification.Notify(); });
+
+  std::vector<uint8_t> last_uses = {true, true, true};
+  execution_context.Call(executable.functions()[0], last_uses,
+                         absl::MakeSpan(args), absl::MakeSpan(&result, 1));
+  mlrt::Execute(execution_context);
+
+  notification.WaitForNotification();
+  EXPECT_THAT(
+      execution_context.status(),
+      ::tsl::testing::StatusIs(absl::StatusCode::kInternal, "Test error"));
+}
+
 }  // namespace
 }  // namespace tf_mlrt
 }  // namespace tensorflow

From 367001b15350f04bdc8e65cb02a5df6d467f3ac4 Mon Sep 17 00:00:00 2001
From: Yu Feng
Date: Mon, 24 Jul 2023 21:10:24 -0700
Subject: [PATCH 091/410] Update __repr__ of layout and mesh.

Making the repr lossless to ease debugging of parted layout support.

Things will be in flux for a while. We can worry about pretty formatting
later.
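
As an illustration (a sketch only, not part of the change): the mesh
construction below mirrors the test setup in layout_test.py, the assertions
follow directly from the new __repr__ implementations, and Layout.replicated
is used merely as a convenient way to build a layout.

    mesh = layout.Mesh([_MESH_DIM_BATCH, _MESH_DIM_X], device_ids,
                       np.ravel(device_ids).tolist(),
                       test_util.create_device_list((4, 2), 'CPU'))
    assert repr(mesh) == f'Mesh.from_string({mesh.to_string()})'

    replicated = layout.Layout.replicated(mesh, rank=2)
    assert repr(replicated) == f'Layout.from_string({replicated.to_string()})'

Either repr can be fed back through the corresponding from_string
constructor to reconstruct the object, which is the sense in which the new
form is lossless.
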
PiperOrigin-RevId: 550761328 --- tensorflow/dtensor/python/layout.py | 9 ++------- tensorflow/dtensor/python/tests/layout_test.py | 9 ++++----- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tensorflow/dtensor/python/layout.py b/tensorflow/dtensor/python/layout.py index 23536343bd3c06..e55039757f692d 100644 --- a/tensorflow/dtensor/python/layout.py +++ b/tensorflow/dtensor/python/layout.py @@ -242,12 +242,7 @@ def __hash__(self): return hash(self.as_proto().SerializeToString(deterministic=True)) def __repr__(self) -> str: - dims = [tuple(self[dim_name]) for dim_name in self.dim_names] - return ( - f'' - ) + return f'Mesh.from_string({self.to_string()})' # TODO(panzf): change to pybind11 pickle implementation in the last step def __reduce__(self): @@ -447,7 +442,7 @@ def _new_object(cls, *args, **kwargs): return self def __repr__(self) -> str: - return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})' + return f'Layout.from_string({self.to_string()})' def __hash__(self): return hash(self.as_proto().SerializeToString(deterministic=True)) diff --git a/tensorflow/dtensor/python/tests/layout_test.py b/tensorflow/dtensor/python/tests/layout_test.py index b8e367a4901dd5..5fdb2d1c008cba 100644 --- a/tensorflow/dtensor/python/tests/layout_test.py +++ b/tensorflow/dtensor/python/tests/layout_test.py @@ -92,9 +92,9 @@ def test_mesh_repr(self): mesh = layout.Mesh([_MESH_DIM_BATCH, _MESH_DIM_X], device_ids, np.ravel(device_ids).tolist(), test_util.create_device_list((4, 2), 'CPU')) - self.assertIn( - ' Date: Mon, 24 Jul 2023 21:21:58 -0700 Subject: [PATCH 092/410] Validate bounding shape and dynamic shape match when dimensions are known. PiperOrigin-RevId: 550763601 --- tensorflow/compiler/tf2xla/BUILD | 3 ++- tensorflow/compiler/tf2xla/shape_util.cc | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 21c605944f8dfd..22d9877bed9713 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -756,13 +756,14 @@ cc_library( deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 8a266c2f1eed96..4818c556a192b6 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -124,9 +126,9 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, } if (tensor_shape.dims() != bound.dims()) { - return errors::InvalidArgument( + return absl::InvalidArgumentError(absl::StrCat( "`tensor_shape` and `bound` have different ranks. 
tensor_shape=", - tensor_shape.dims(), "vs bound=", bound.dims()); + tensor_shape.dims(), "vs bound=", bound.dims())); } int rank = tensor_shape.dims(); @@ -134,8 +136,14 @@ Status TensorShapeToBoundedXLAShape(DataType dtype, std::vector layout(rank); for (int d = 0; d < rank; ++d) { if (bound.dim_size(d) < 0) { - return errors::InvalidArgument("Bound dimension ", d, - " has unknown size."); + return absl::InvalidArgumentError( + absl::StrCat("Bound dimension ", d, " has unknown size.")); + } + if (tensor_shape.dim_size(d) > 0 && + bound.dim_size(d) != tensor_shape.dim_size(d)) { + return absl::InvalidArgumentError(absl::StrCat( + "Bounding shape does not match dynamic shape for known dimension ", d, + tensor_shape.dim_size(d), " vs ", bound.dim_size(d))); } dimensions[d] = bound.dim_size(d); } From 2a5cbd2a8521625a7b93dca9ecfa052965e50c7c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Jul 2023 21:41:18 -0700 Subject: [PATCH 093/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/185b3fe2a676620ece266b6a3e4df6d5fb7264ae. PiperOrigin-RevId: 550766697 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 9f8fcca21eb01d..3b2e8a79e554b7 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "76c4d009a37be44b2df2c79f1f74f40369b56dd0" - TFRT_SHA256 = "e5564ff77f20c7bc60cb8677f04a5dfdb7c221f98fd1c09095295e43e658b776" + TFRT_COMMIT = "185b3fe2a676620ece266b6a3e4df6d5fb7264ae" + TFRT_SHA256 = "6c9f82696dac444d2d5c8a409c329170145bf6304de910ec0d1f608fec3f2d92" tf_http_archive( name = "tf_runtime", From 93714549079144851619fb13690070eb444bb4bd Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Mon, 24 Jul 2023 23:22:49 -0700 Subject: [PATCH 094/410] Extract reduction fusion code from ir_emitter_unnested. 
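
As the file list below shows, this removes the standalone
fusions/reduction.{cc,h} and fusions/thunk_util.{cc,h} files and folds the
reduction emission code back into ir_emitter_unnested; GetFusionEmitter in
fusions/fusions.cc no longer returns a dedicated emitter for the kReduction
fusion kind.
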
PiperOrigin-RevId: 550783724 --- tensorflow/compiler/xla/service/gpu/BUILD | 4 - .../compiler/xla/service/gpu/fusions/BUILD | 46 - .../xla/service/gpu/fusions/fusion_emitter.cc | 4 +- .../xla/service/gpu/fusions/fusion_emitter.h | 9 - .../xla/service/gpu/fusions/fusions.cc | 4 - .../xla/service/gpu/fusions/reduction.cc | 958 ------------------ .../xla/service/gpu/fusions/reduction.h | 120 --- .../xla/service/gpu/fusions/thunk_util.cc | 120 --- .../xla/service/gpu/fusions/thunk_util.h | 37 - .../xla/service/gpu/fusions/tiling_util.cc | 33 - .../xla/service/gpu/fusions/tiling_util.h | 5 - .../compiler/xla/service/gpu/ir_emitter.cc | 4 - .../xla/service/gpu/ir_emitter_unnested.cc | 871 +++++++++++++++- .../xla/service/gpu/ir_emitter_unnested.h | 158 +++ 14 files changed, 1025 insertions(+), 1348 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.h delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4bf5b9d4159a14..13ade65c0b8a46 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -309,7 +309,6 @@ cc_library( ":fft_thunk", ":gemm_thunk", ":gpu_asm_opts_util", - ":gpu_constants", ":gpu_conv_runner", ":gpu_device_info", ":gpu_executable", @@ -342,10 +341,8 @@ cc_library( "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/gpu/fusions", - "//tensorflow/compiler/xla/service/gpu/fusions:thunk_util", "//tensorflow/compiler/xla/service/gpu/fusions:tiling_util", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", - "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -409,7 +406,6 @@ cc_library( ":backend_configs_cc", ":hlo_fusion_analysis", ":hlo_to_ir_bindings", - ":ir_emission_utils", ":ir_emitter_context", ":kernel_reuse_cache", ":target_util", diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 644544246200ce..0673e218c07396 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -61,7 +61,6 @@ cc_library( ":fusion_emitter", ":in_place_dynamic_update_slice", ":loop", - ":reduction", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", @@ -102,48 +101,3 @@ cc_library( "@llvm-project//llvm:ir_headers", ], ) - -cc_library( - name = "reduction", - srcs = ["reduction.cc"], - hdrs = ["reduction.h"], - deps = [ - ":fusion_emitter", - ":thunk_util", - ":tiling_util", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:ir_emitter", - "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", - "//tensorflow/compiler/xla/service/gpu:kernel_reuse_cache", - 
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", - "//tensorflow/compiler/xla/service/gpu:target_util", - "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ir_array", - "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", - "@llvm-project//llvm:ir_headers", - ], -) - -cc_library( - name = "thunk_util", - srcs = ["thunk_util.cc"], - hdrs = ["thunk_util.h"], - visibility = ["//tensorflow/compiler/xla/service/gpu:__subpackages__"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", - "//tensorflow/compiler/xla/service/gpu:thunk", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - ], -) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc index ab839a522c5df0..b4c60d1c2bc334 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc @@ -74,8 +74,6 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, } } -} // namespace - std::tuple, std::vector> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, @@ -173,6 +171,8 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, return {kernel, std::move(inputs), std::move(outputs)}; } +} // namespace + StatusOr KernelFusionEmitterBase::Emit( KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { std::string suggested_kernel_name = GetIrNameFromLoc(fusion_op_->getLoc()); diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h index b31247bd987da6..0fdc738a29eee5 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h @@ -80,15 +80,6 @@ class KernelFusionEmitterBase : public FusionInterface { const HloFusionInstruction& fusion_; }; -std::tuple, - std::vector> -BuildKernelPrototype(IrEmitterContext& ir_emitter_context, - const std::string& suggested_name, - absl::Span arguments, - size_t num_inputs, - const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* builder); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index 34723c1ec64eef..a6c800556c5c7a 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/fusions/copy.h" #include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -50,9 +49,6 @@ std::optional> GetFusionEmitter( ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion) { switch (analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kReduction: - return std::make_unique( - ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { bool is_single = IsSingleInstructionFusion(fusion_op); if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc b/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc deleted file mode 100644 index f139bba8f1a302..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc +++ /dev/null @@ -1,958 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" - -#include -#include - -#include "llvm/IR/IRBuilder.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_reuse_cache.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/target_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" - -namespace xla { -namespace gpu { -namespace { - -using TypedPointer = std::pair; - -// Fusion root -> array of indexes, one per reduction output. 
-using ReductionOutputMap = - ConstHloInstructionMap>; -using ExtraOutputGensMap = ConstHloInstructionMap; - -void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - auto* module = builder->GetInsertBlock()->getModule(); - if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { - builder->CreateFence( - llvm::AtomicOrdering::SequentiallyConsistent, - builder->getContext().getOrInsertSyncScopeID("workgroup")); - } -} - -void EmitSyncThreads(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); -} - -// For a row reduction, returns the number of rows we can process in parallel -// per warp. -int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; -} - -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); -} - -Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, - const Shape& reduction_operand_shape, - const ReductionOutputMap& result_ir_arrays, - const llvm_ir::IrArray::Index& index, - const ReductionCodegenInfo& reduction_info, - const ExtraOutputGensMap& extra_output_gens) { - if (extra_output_gens.empty()) { - return OkStatus(); - } - - // Compute all extra output values before writing them. This avoids - // overwriting aliased input/output buffers before all reads occurred. - std::vector> - extra_output_ir_values; - extra_output_ir_values.reserve(extra_output_gens.size()); - - auto get_index = [&](const HloInstruction* instr) { - const Shape& s = instr->shape(); - return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) - ? 
index - : index.SourceIndexOfBitcast(reduction_operand_shape, s, - builder); - }; - - for (const auto& [instr, generator] : extra_output_gens) { - TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, - generator(get_index(instr))); - extra_output_ir_values.emplace_back(instr, extra_output_ir_value); - } - - for (const auto& [instr, generator] : extra_output_ir_values) { - absl::Span result_ir = result_ir_arrays.at(instr); - CHECK_EQ(result_ir.size(), 1); - result_ir[0].EmitWriteArrayElement( - get_index(instr), generator, builder, /*use_linear_index=*/ - reduction_info.GetNumPartialResults() == 1); - } - return OkStatus(); -} - -ReductionCodegenState GenerateReductionCodegenState( - llvm::IRBuilder<>* builder, mlir::lmhlo::FusionOp fusion, - const ReductionCodegenInfo& reduction_info, - absl::Span reduce_instr_index_group, - FusedIrEmitter& fused_emitter) { - ReductionCodegenState reduction_codegen_state(reduction_info); - VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); - - for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - int num_partial_results = reduction_codegen_state.GetNumPartialResults(); - int num_outputs = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes_size() - : 1; - for (int op_result_idx = 0; op_result_idx < num_outputs; op_result_idx++) { - Shape result_shape = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes(op_result_idx) - : reduce_hlo->shape(); - - llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - result_shape.element_type(), builder->GetInsertBlock()->getModule()); - llvm::AllocaInst* reduction_input_address = - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "reduction_input_address", builder); - - llvm::AllocaInst* partial_result_address = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - element_type, - /*element_count=*/builder->getInt32(num_partial_results), - "partial_reduction_result", builder); - - const HloInstruction* init_value = - reduce_hlo->init_values()[op_result_idx]; - - // Initialize the partial result with the initial value of the reduction. - llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( - *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) - .value(); - - for (int i = 0; i < num_partial_results; ++i) { - builder->CreateStore( - init_ir_value, builder->CreateInBoundsGEP( - partial_result_address->getAllocatedType(), - partial_result_address, {builder->getInt32(i)})); - } - - const TilingScheme& tiling_scheme = - reduction_codegen_state.GetTilingScheme(); - int64_t num_threads_x = - tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); - llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { - if (reduction_codegen_state.IsRowReduction()) { - // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > - 1) { - return nullptr; - } - // Allocate __shared__ - // cache[num_partial_results][num_warps][scaling_factor]. - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); - int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(builder, tiling_scheme, element_type, - {num_partial_results, num_warps}, - "shared_cache"); - } else { - // Allocate __shared__ - // cache[num_threads][num_threads + 1], where - // num_threads == num_threads_x == num_threads_y. The "+1" is used to - // avoid bank conflicts. 
- // - // (Although each thread produces num_partial_results results, we - // don't need that much cache: Only one result is live at a time.) - CHECK_EQ(num_threads_x, - tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); - return AllocateShared(builder, tiling_scheme, element_type, - {num_threads_x, num_threads_x + 1}, - "shared_cache"); - } - }(); - - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - reduction_codegen_state.SetCalculationStateFor( - {shared_cache, init_ir_value, partial_result_address, - reduction_input_address, input_gen}, - reduce_hlo, op_result_idx); - } - } - - return reduction_codegen_state; -} - -// Generate a single element of the tile (update the accumulator state) for a -// given reducer of index `i`. -void GenerateElementForReducer( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays) { - HloComputation* reducer = reduction->to_apply(); - CHECK_EQ(reducer->num_parameters() % 2, 0); - - absl::InlinedVector reduction_accumulators; - absl::InlinedVector reduction_input_value; - for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - codegen_state.GetCalculationStateFor(reduction, red_idx); - - llvm::AllocaInst* input_address = state.input_address; - llvm::AllocaInst* partial_reduction_result_address = - state.partial_result_address; - llvm::Value* const input_ir_value = *state.input_gen( - num_partial_results > 1 ? index_without_linear : input_index); - builder->CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = builder->CreateInBoundsGEP( - partial_reduction_result_address->getAllocatedType(), - partial_reduction_result_address, {partial_result_index}); - reduction_accumulators.push_back(partial_result_address); - reduction_input_value.push_back(input_address); - } - - absl::InlinedVector reduction_params; - for (llvm::Value* acc : reduction_accumulators) { - reduction_params.push_back(acc); - } - for (llvm::Value* value : reduction_input_value) { - reduction_params.push_back(value); - } - - // Emit a call to the variadic reducer. Since it may be returning a - // tuple, we can't return it directly as a value. Instead, before - // the call, we create N (N = # arguments in the tuple) allocas, one - // for each returned argument, then when we make the call we pass N - // pointers as last parameters, the called computation writes into - // those pointers, and we have returned values on the stack (as well - // as pointers to them). - StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, - *reducer, reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); - } -} - -// Emits shuffle-down reduction for the `partial_result_address` using the -// reduction computation `reducer`, writes output into -// `partial_result_address`. -// -// Multiple partial_result_address inputs happen when doing variadic -// reduction: each one should get the output value. 
-void EmitFullWarpShuffleDownLoopForReduce( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp = 1) { - // This only works when the block size is a multiple of 32 threads. - - // We check this here as a mistake in the number of threads per - // block is very hard to detect. - CHECK_EQ(threads_per_block % 32, 0); - CHECK_EQ(WarpSize() % num_results_per_warp, 0); - - for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { - absl::InlinedVector reduction_params; - - for (auto acc : partial_result_addresses) { - reduction_params.push_back(acc.first); - } - - for (auto [partial_result_address, element_type] : - partial_result_addresses) { - int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", builder); - - reduction_params.push_back(result_from_other_lane); - - // Bitcast cannot be applied to aggregate types (even packed ones), so - // we bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffled_value_type = element_type->isStructTy() - ? builder->getIntNTy(bit_width) - : element_type; - auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { - return builder->CreatePointerBitCastOrAddrSpaceCast( - ptr, shuffled_value_type->getPointerTo()); - }; - - llvm::Value* partial_result = builder->CreateLoad( - shuffled_value_type, - convert_pointer_for_shuffle(partial_result_address), - "partial_reduction_result"); - builder->CreateStore( - EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), - builder), - convert_pointer_for_shuffle(result_from_other_lane)); - } - - StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, - *reducer, reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); - } - } -} - -// Gets the output offset as calculated from thread_id.x (to be applied to the -// offset calculated from block_id and thread_id.y). -llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, - llvm::Value* thread_id_x, llvm::Type* index_ty, - llvm::IRBuilder<>* b) { - int64_t multiplier = - tiling_scheme.GetIndexingOrder() == TilingScheme::StridedIndexingX - ? 
tiling_scheme.GetVectorSize() - : tiling_scheme.GetTileSizeFor(TilingScheme::DimX); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, multiplier)); -} - -llvm::Value* GetOutputAddressForReduction( - llvm::IRBuilder<>* builder, int partial_result_idx, llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx) { - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; - - llvm_ir::IrArray::Index start_offset = [&] { - llvm::Value* x_loc = thread_id_info.thread_id_x; - llvm::Value* y_loc = thread_id_info.thread_id_y; - if (!reduction_codegen_state.IsRowReduction()) { - std::swap(x_loc, y_loc); - } - llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, x_loc, index_ty, builder); - return tiling_kernel_info.tile_origin - .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); - }(); - - const llvm_ir::IrArray& output_array = - output_arrays.at(reduction)[output_idx]; - const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); - Shape reduction_kept_element_shape = - ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); - - // Given the IrArray index of a reduction input, returns the linear address of - // the reduction output as if the reduction were going to keep the input shape - // with the dimensions being reduced moved. - llvm::Value* untransposed_output_linear_address = [&] { - const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( - constant(partial_result_idx), TilingScheme::DimX, builder); - if (reduction_codegen_state.IsRowReduction()) { - // For row-reduction, y-coordinate determines which row we write into. - return index[TilingScheme::DimY]; - } - // For column reduction, we get the transposed address. - absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); - llvm::Value* x_dim_size = - index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); - llvm::Value* x_block_offset = - builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); - }(); - - // A reduction is allowed to transpose its output. For example, suppose - // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are - // allowed to produce as output either f32[10,30]{1,0} (no transpose) or - // f32[10,30]{0,1} (transposing the two output dims). - // - // At this point in the function we have a "partial sum" of input elements - // (stored in partial_result_addresses), and we need to accumulate it into - // the correct output element. - llvm_ir::IrArray::Index element_index( - /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder); - llvm_ir::IrArray::Index output_index(element_index.multidim(), - output_array.GetShape(), - element_index.GetType()); - - return output_array.EmitArrayElementAddress(output_index, builder, - "output_element_address"); -} - -// Wraps up the code generation for a tile block of a reduction kernel: -// write the calculated output into the output tensor. 
-void WriteReductionOutput(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, - int partial_result_idx, - const absl::Span values) { - const HloComputation* reducer = reduction->to_apply(); - for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { - auto [output_ptr, type] = typed_ptr; - llvm::Value* output_address = GetOutputAddressForReduction( - builder, partial_result_idx, index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, oidx); - if (reduction_codegen_state.IsRaceFree()) { - builder->CreateStore(builder->CreateLoad(type, output_ptr, "output"), - output_address); - } else { - CHECK_EQ(values.size(), 1); - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder, ir_emitter_context, *reducer, output_address, output_ptr, - type)); - } - } -} - -// `current_output`: the value the tile has calculated. -// `output_address`: address where the output value has to be written. -void EmitReductionOutputForRowReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - - int num_outputs = reducer->num_parameters() / 2; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - absl::InlinedVector current_outputs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - current_outputs.push_back( - {builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"), - state.partial_result_address->getAllocatedType()}); - } - - int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; - int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); - EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); - - KernelSupportLibrary ksl(builder); - llvm::Value* warp_id = - builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); - - auto emit_write_output = [&](llvm::Value* write_condition, - const absl::Span values) { - ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(builder, ir_emitter_context, index_ty, - reduction_codegen_state, tiling_kernel_info, - output_arrays, reduction, partial_result_idx, - values); - }); - }; - - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } - - ksl.If("intra_warp_reduce_write", 
is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, {constant(partial_result_idx), warp_id}); - builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - shmem_output_addr); - } - }); - - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? - EmitSyncThreads(builder, ir_emitter_context); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {constant(partial_result_idx), thread_id_info.lane_id}); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - // Ensure initial value address is in generic, not scratch. - llvm::Value* initial_value_addr = builder->CreateAddrSpaceCast( - llvm_ir::EmitAllocaAtFunctionEntry(element_type, "initial_value_addr", - builder), - llvm::PointerType::get(element_type, - /*AddressSpace=*/0)); - builder->CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = builder->CreateICmpULT( - thread_id_info.thread_id_x, - constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / - WarpSize())); - - llvm::Value* selected_value = builder->CreateSelect( - warp_exists, block_accum_addr, initial_value_addr); - - selected_values.push_back({selected_value, element_type}); - } - - // If only one warp is present in the block, then we don't need inter-warp - // reduction. - // TODO(b/241414088) If only warp is present, then inter-warp communication - // using shared memory and synchronization using barrier is also unnecessary - // and should be removed. - if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(selected_values), - tiling_scheme.GetNumThreadsPerBlock()); - } - - emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); - }); -} - -// Same arguments as EmitReductionOutputForRowReduction. -void EmitReductionOutputForColumnReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - KernelSupportLibrary ksl(builder); - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - int num_outputs = reducer->num_parameters() / 2; - - // Wait for reads from shmem in the last iteration to complete. (If this is - // slow, we could "double-buffer" by having two shmem buffers and switching - // between them.) 
- if (partial_result_idx > 0) { - EmitSyncThreads(builder, ir_emitter_context); - } - - // Store the transpose in shared memory. - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::GlobalVariable* shared_cache = state.shared_cache; - llvm::AddrSpaceCastInst* shmem_output_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, shared_cache, - {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, - "shmem_output_address")); - llvm::Value* current_output = builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"); - - llvm::Value* current_output_value = builder->CreateLoad( - state.partial_result_address->getAllocatedType(), current_output); - builder->CreateStore(current_output_value, shmem_output_addr); - } - - EmitSyncThreads(builder, ir_emitter_context); - - // Get transposed element from shared memory. - absl::InlinedVector shmem_transposed_addrs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::AddrSpaceCastInst* shmem_transposed_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, - "shmem_transposed_addr")); - shmem_transposed_addrs.push_back( - {shmem_transposed_addr, llvm::cast( - shmem_transposed_addr->getPointerOperand()) - ->getResultElementType()}); - } - - EmitFullWarpShuffleDownLoopForReduce(builder, ir_emitter_context, reducer, - absl::MakeSpan(shmem_transposed_addrs), - tiling_scheme.GetNumThreadsPerBlock()); - - // Some warps in the block are completely outside of the bound of the - // tensor, so they should not write any output at all. - llvm::Value* has_output = builder->CreateAnd( - builder->CreateICmpULT( - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty, - builder), - tiling_kernel_info.output_tile_bounds[1]), - builder->CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); - - ksl.If("reduction_write_output", - builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(builder, ir_emitter_context, index_ty, - reduction_codegen_state, tiling_kernel_info, - output_arrays, reduction, partial_result_idx, - shmem_transposed_addrs); - }); -} - -// Emits code for reductions in the output_instructions. 
-Status EmitIRForReduction(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - mlir::lmhlo::FusionOp fusion, - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, - const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenInfo& reduction_info, - const Shape& input_shape) { - std::vector reductions; - ExtraOutputGensMap extra_output_gens; - - for (const HloInstruction* hlo : instr_index_group) { - if (IsReductionFromOrToContiguousDimensions(*hlo)) { - reductions.push_back(Cast(hlo)); - } else { - extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); - } - } - - CHECK(!reductions.empty()) << " expect at least one reduce instructions."; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); - llvm::Type* index_ty = - GetIndexTypeForKernel(fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumberOfBlocksPhysical(), - builder); - ReductionCodegenState codegen_state = GenerateReductionCodegenState( - builder, fusion, reduction_info, reductions, fused_emitter); - - EmitTileElementFunction emit_reduction_element = - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( - index, input_shape, builder, - codegen_state.GetTilingScheme().GetDimsInElems()); - llvm::Value* partial_result_index = - codegen_state.IsRowReduction() - ? builder->getInt32(0) - : builder->CreateSub( - x_loc, - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, - index_ty, builder)); - - // Clear the linear index field of the IrArray::Index to enable the use - // of GetElementPointer with array types. This enables the vectorization - // of the computation for different partial results. Use this index if - // 'num_partial_results > 1'. - int num_partial_results = codegen_state.GetNumPartialResults(); - llvm_ir::IrArray::Index index_without_linear{ - input_index.multidim(), input_shape, input_index.GetType()}; - - // Emit code to generate the input and perform the reduction computation - // for each reduction instruction. - for (const HloReduceInstruction* reduce : reductions) { - GenerateElementForReducer(builder, ir_emitter_context, reduce, - partial_result_index, codegen_state, - index_without_linear, input_index, - num_partial_results, result_ir_arrays); - } - - // Emit code to generate the output for the non-reduction instructions - // in the fusion, if any. 
- TF_CHECK_OK(EmitExtraOutputsForReduce( - builder, input_shape, result_ir_arrays, input_index, reduction_info, - extra_output_gens)); - }; - - TF_ASSIGN_OR_RETURN( - TilingKernelInfo tiling_kernel_info, - EmitTilingKernel(builder, tiling_scheme, index_ty, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, - std::array tile_dimensions) { - EmitTile(builder, codegen_state.GetTilingScheme(), - index, thread_id_info, tile_dimensions, - emit_reduction_element); - })); - - KernelSupportLibrary ksl(builder); - for (const HloReduceInstruction* reduce : reductions) { - for (int partial_result_idx = 0; - partial_result_idx < reduction_info.GetNumPartialResults(); - ++partial_result_idx) { - if (codegen_state.IsRowReduction()) { - EmitReductionOutputForRowReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, partial_result_idx); - } else { - EmitReductionOutputForColumnReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, partial_result_idx); - } - } - } - - return OkStatus(); -} - -StatusOr> BuildKernelThunkForFusion( - IrEmitterContext& ir_emitter_context, KernelReuseCache& kernel_cache, - mlir::lmhlo::FusionOp fusion_op, const HloInstruction& fusion, - const LaunchDimensions& launch_dimensions, absl::string_view discriminator, - std::function, - std::vector)> - kernel_builder_fn, - llvm::IRBuilder<>* builder) { - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - KernelArguments::Create(ir_emitter_context.allocations(), fusion_op)); - - Status kernel_builder_status = OkStatus(); - auto [entry, cached] = kernel_cache.Get( - fusion.fused_instructions_computation(), kernel_arguments.args(), - discriminator, [&]() -> KernelReuseCache::Entry { - std::vector inputs, outputs; - llvm::Function* kernel; - std::tie(kernel, inputs, outputs) = BuildKernelPrototype( - ir_emitter_context, GetIrNameFromLoc(fusion_op->getLoc()), - kernel_arguments.args(), fusion.fused_parameters().size(), - launch_dimensions, builder); - kernel_builder_status = - kernel_builder_fn(std::move(inputs), std::move(outputs)); - return {kernel->getName().str(), launch_dimensions}; - }); - TF_RETURN_IF_ERROR(kernel_builder_status); - - return std::make_unique( - fusion_op, entry.kernel_name, kernel_arguments.args(), launch_dimensions); -} - -StatusOr> BuildFusedInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, - const HloInstruction& fusion, ElementalIrEmitter& elemental_emitter, - KernelReuseCache& kernel_cache, int output_index, - llvm::IRBuilder<>* builder) { - auto reduce = mlir::dyn_cast_or_null( - fusion_op.getFusionRoots()[output_index]); - - TF_RET_CHECK(reduce); - TF_RET_CHECK(reduce.getNumResults() == 1); - - mlir::Value init_value = reduce.getInitValues()[0]; - mlir::Value dest = fusion_op.getOutputBuffers()[output_index]; - TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, - BuildConstantInitializerThunk( - ir_emitter_context, fusion_op, init_value, dest)); - if (constant_init_thunk) { - return *std::move(constant_init_thunk); - } - - auto input_buffers = fusion_op.getInputBuffers(); - - const Shape dest_shape = GetShape(dest); - bool use_experimental_block_size = - ir_emitter_context.debug_options() - .xla_gpu_enable_experimental_block_size(); - - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - dest_shape, ir_emitter_context.gpu_device_info(), - use_experimental_block_size)); - - const 
HloComputation* fused_computation = - fusion.fused_instructions_computation(); - HloInstruction* instr = fused_computation->root_instruction(); - if (instr->opcode() != HloOpcode::kTuple) { - CHECK_EQ(0, output_index); - } else { - instr = instr->mutable_operand(output_index); - } - TF_RET_CHECK(instr->shape().IsArray()); - - auto kernel_builder = [&](std::vector inputs, - std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [builder, &inputs, - i](llvm_ir::IrArray::Index index) -> StatusOr { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - TF_ASSIGN_OR_RETURN(auto generator, - fused_emitter.GetGenerator(*instr->operand(1))); - return ParallelLoopEmitter(generator, outputs, launch_dimensions, builder) - .EmitLoop(GetIrNameFromLoc(fusion_op.getLoc())); - }; - return BuildKernelThunkForFusion( - ir_emitter_context, kernel_cache, fusion_op, fusion, launch_dimensions, - /*discriminator=*/ - absl::StrCat("init_", output_index), kernel_builder, builder); -} - -} // namespace - -StatusOr ReductionFusion::Emit( - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { - auto* reduction_codegen_info = analysis_.GetReductionCodegenInfo(); - // Set `use_experimental_block_size` flag to false as the reduction code - // has its own custom logic of choosing a block size. - TF_ASSIGN_OR_RETURN(auto launch_dimensions, - analysis_.GetLaunchDimensions( - /*use_experimental_block_size=*/false)); - - FusionEmissionResult result; - VLOG(3) << "Launch dimensions of " - << mlir::mhlo::GetDebugNameFromLocation(fusion_op().getLoc()) << ": " - << launch_dimensions.ToString(); - if (!reduction_codegen_info->IsRaceFree()) { - absl::Span fusion_roots = analysis_.fusion_roots(); - for (int i = 0; i < fusion_roots.size(); ++i) { - if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { - TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(), - BuildFusedInitializerThunk( - ir_emitter_context_, fusion_op(), fusion_, - elemental_emitter_, kernel_cache, i, builder)); - } - } - } - - auto kernel_builder = [&](std::vector inputs, - std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter_); - const HloComputation* fused_computation = analysis_.fused_computation(); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - llvm_ir::IrArray ir_array = inputs[i]; - HloInstruction* fused_operand = - fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, - [builder, ir_array, fused_operand]( - const llvm_ir::IrArray::Index& index) -> StatusOr { - return ir_array.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - // Get outputs. - ReductionOutputMap result_ir_arrays; - - // Skip all parameter buffers first. - int ir_arrays_idx = 0; - auto outputs_span = absl::MakeSpan(outputs); - for (HloInstruction* root : analysis_.fusion_roots()) { - int num_results = - root->shape().IsTuple() ? root->shape().tuple_shapes_size() : 1; - result_ir_arrays[root] = outputs_span.subspan(ir_arrays_idx, num_results); - ir_arrays_idx += num_results; - } - - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - - // Use raw block_id_y to select the i-th parallel reduction to run. 
Using - // block_id_y instead of block_id_x simplifies the index calculation - // for reduction code generation as the block_id_y is orthogonal to - // the indices used within the reductions. - const std::vector>& instr_index_groups = - reduction_codegen_info->GetIndexGroups(); - Shape reduce_operand_shape = - reduction_codegen_info->GetReduceOperandShape(); - - llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(raw_block_id_y)); - for (int i = 0; i < instr_index_groups.size(); ++i) { - TF_RETURN_IF_ERROR(ksl.IfWithStatus( - absl::StrCat("reduce-group-", i), - builder->CreateICmpEQ(raw_block_id_y, builder->getInt32(i)), [&] { - return EmitIRForReduction(builder, ir_emitter_context_, fusion_op(), - instr_index_groups[i], fused_emitter, - result_ir_arrays, *reduction_codegen_info, - reduce_operand_shape); - })); - } - - return OkStatus(); - }; - - TF_ASSIGN_OR_RETURN( - result.thunks.emplace_back(), - BuildKernelThunkForFusion(ir_emitter_context_, kernel_cache, fusion_op(), - fusion_, launch_dimensions, "", kernel_builder, - builder)); - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.h b/tensorflow/compiler/xla/service/gpu/fusions/reduction.h deleted file mode 100644 index 5bb45d4e3c01ea..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/reduction.h +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ - -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" - -namespace xla { -namespace gpu { - -// Generates code for reduction to contiguous dimensions. 
-// -// Row reduction uses the following algorithm described in CUDA-like -// pseudocode: -// -// ``` -// __global__ void reduce(int num_rows, float *in, float out) { -// __shared__ float[32] cache; -// int offset = blockDim.x * blockIdx.x + threadIdx.x; -// if (offset >= num_rows) return; -// int tile_bound = std::min(offset + kTileSizeX, num_rows); -// float accum = 0; -// for (int i=offset; i Emit( - KernelReuseCache& kernel_cache, - llvm::IRBuilder<>* builder) const override; - - private: - mlir::lmhlo::FusionOp fusion_op() const { return fusion_op_; } - - IrEmitterContext& ir_emitter_context_; - ElementalIrEmitter& elemental_emitter_; - mlir::lmhlo::FusionOp fusion_op_; - const HloFusionInstruction& fusion_; - HloFusionAnalysis& analysis_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc deleted file mode 100644 index e38504096376a8..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" - -#include -#include - -#include "absl/types/span.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/shape.h" -#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" - -namespace xla { -namespace gpu { -namespace { - -// TODO(b/291536641): Clean this up. What's the difference between this and the -// caller? -std::optional> BuildConstantInitializerThunk( - mlir::Operation* op, absl::Span init_value, mlir::Value dest, - const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { - int64_t num_bytes = init_value.size(); - if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { - return {{std::make_unique(Thunk::ThunkInfo(op), dest_slice, - dest)}}; - } - - // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the literal 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. 
- if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { - uint16_t pattern16; - if (num_bytes == 1) { - uint8_t b = init_value.front(); - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, init_value.data(), sizeof(pattern16)); - } - uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); - return {{std::make_unique( - Thunk::ThunkInfo(op), pattern32, dest_slice, dest)}}; - } - - // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit - // memset so long as all 32-bit words of the scalar are equal to each other. - if (num_bytes >= 4 && num_bytes % 4 == 0 && - memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == - 0) { - uint32_t word; - memcpy(&word, init_value.data(), sizeof(word)); - return {{std::make_unique(Thunk::ThunkInfo(op), word, - dest_slice, dest)}}; - } - - return std::nullopt; -} - -} // namespace - -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - mlir::Value init_value, mlir::Value dest) { - mlir::DenseElementsAttr const_init; - if (auto get_global_memref = - mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - auto global_memref = - mlir::SymbolTable::lookupNearestSymbolFrom( - get_global_memref, get_global_memref.getNameAttr()); - if (global_memref.getConstant() && global_memref.getInitialValue()) { - // If the initial value happens to be a constant, generate a specialized - // thunk. - const_init = global_memref.getInitialValue() - .value() - .cast(); - } - } else if (auto constant = mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - const_init = constant.getValue().dyn_cast(); - } - - if (const_init) { - std::vector literal_bytes; - TF_RETURN_IF_ERROR( - CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); - - TF_ASSIGN_OR_RETURN( - auto dest_slice, - GetAllocationSlice(dest, ir_emitter_context.allocations())); - - const Shape dest_shape = GetShape(dest); - return BuildConstantInitializerThunk(op, literal_bytes, dest, dest_slice, - dest_shape); - } - return std::nullopt; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h deleted file mode 100644 index b52c535f034654..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Attempts to build an initializer constant for the given value. Returns an -// empty optional if the value is not a constant. -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - mlir::Value init_value, mlir::Value dest); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc index 8d0ca80b2d9182..c0c2640a895937 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc @@ -360,38 +360,5 @@ llvm::Value* TilingThreadIdInfo::GEPIntoSharedMemory( return b->CreateAddrSpaceCast(gep, pointer_in_addressspace); } -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems) { - CHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. 
- if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[1], multidim[2]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && - unnormalized_shape.layout().minor_to_major(1) == 1) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[2], multidim[1]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - return normalized_shape_index.SourceIndexOfBitcast( - ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, builder); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h index be697addeb78e0..1d8dd40e916300 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h @@ -138,11 +138,6 @@ StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, llvm::Type* index_ty, const TileElementGenerator& tile_element_generator); -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 95d0a67c6c4f27..3070cbf9b1dfa7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,17 +21,13 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 3e23e291c96330..d51ef9046c6e5b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -92,7 +92,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fused_mha_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fusions/fusions.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" #include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" @@ -2047,6 +2046,8 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kReduction: + return EmitUnnestedReduction(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kTranspose: return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: @@ -2054,7 +2055,6 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: - case HloFusionAnalysis::EmitterFusionKind::kReduction: return FailedPrecondition( "Loop fusion should have been handled by GetFusionEmitter."); } @@ -3207,6 +3207,86 @@ IrEmitterUnnested::BuildKernelThunkForNonFusionOp( launch_dimensions); } +std::unique_ptr IrEmitterUnnested::BuildConstantInitializerThunk( + mlir::Operation* op, absl::Span init_value, mlir::Value dest, + const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { + int64_t num_bytes = init_value.size(); + if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { + return std::make_unique(Thunk::ThunkInfo(op), dest_slice, + dest); + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { + uint16_t pattern16; + if (num_bytes == 1) { + uint8_t b = init_value.front(); + pattern16 = uint16_t{b} | (uint16_t{b} << 8); + } else { + memcpy(&pattern16, init_value.data(), sizeof(pattern16)); + } + uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); + return std::make_unique(Thunk::ThunkInfo(op), + pattern32, dest_slice, dest); + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == + 0) { + uint32_t word; + memcpy(&word, init_value.data(), sizeof(word)); + return std::make_unique(Thunk::ThunkInfo(op), word, + dest_slice, dest); + } + + return nullptr; +} + +StatusOr> +IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Operation* op, + mlir::Value init_value, + mlir::Value dest) { + mlir::DenseElementsAttr const_init; + if (auto get_global_memref = + mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + auto global_memref = + mlir::SymbolTable::lookupNearestSymbolFrom( + get_global_memref, get_global_memref.getNameAttr()); + if (global_memref.getConstant() && global_memref.getInitialValue()) { + // If the initial value happens to be a constant, generate a specialized + // thunk. 
+ const_init = global_memref.getInitialValue() + .value() + .cast(); + } + } else if (auto constant = mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + const_init = constant.getValue().dyn_cast(); + } + + if (const_init) { + std::vector literal_bytes; + TF_RETURN_IF_ERROR( + CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); + + TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest)); + + const Shape dest_shape = GetShape(dest); + auto thunk = BuildConstantInitializerThunk(op, literal_bytes, dest, + dest_slice, dest_shape); + if (thunk) { + return {std::move(thunk)}; + } + } + return std::unique_ptr(); +} + Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest) { @@ -3214,11 +3294,10 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, auto init_type = init_value.getType().dyn_cast(); TF_RET_CHECK(init_type.getRank() == 0); - TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, - BuildConstantInitializerThunk(*ir_emitter_context_, op, - init_value, dest)); + TF_ASSIGN_OR_RETURN(std::unique_ptr constant_init_thunk, + TryBuildConstantInitializerThunk(op, init_value, dest)); if (constant_init_thunk) { - AddThunkToThunkSequence(*std::move(constant_init_thunk)); + AddThunkToThunkSequence(std::move(constant_init_thunk)); return OkStatus(); } @@ -3250,6 +3329,77 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, return OkStatus(); } +Status IrEmitterUnnested::BuildFusedInitializerThunk( + mlir::lmhlo::FusionOp fusion, int output_index) { + auto reduce = mlir::dyn_cast_or_null( + fusion.getFusionRoots()[output_index]); + + TF_RET_CHECK(reduce); + TF_RET_CHECK(reduce.getNumResults() == 1); + + mlir::Value init_value = reduce.getInitValues()[0]; + mlir::Value dest = fusion.getOutputBuffers()[output_index]; + TF_ASSIGN_OR_RETURN( + std::unique_ptr constant_init_thunk, + TryBuildConstantInitializerThunk(fusion, init_value, dest)); + if (constant_init_thunk) { + AddThunkToThunkSequence(std::move(constant_init_thunk)); + return OkStatus(); + } + + auto input_buffers = fusion.getInputBuffers(); + + const Shape dest_shape = GetShape(dest); + bool use_experimental_block_size = + ir_emitter_context_->debug_options() + .xla_gpu_enable_experimental_block_size(); + + TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, + CalculateLaunchDimensions( + dest_shape, ir_emitter_context_->gpu_device_info(), + use_experimental_block_size)); + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion( + fusion, launch_dimensions, + /*discriminator=*/absl::StrCat("init_", output_index))); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. 
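// Note: when the kernel cache reports a hit, BuildKernelThunkForFusion has
// already recorded a KernelThunk that launches the previously emitted
// kernel; only the device-side code generation is skipped here, since an
// identical kernel already exists in the module.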
+ return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + const llvm_ir::IrArray dest_array = + ir_arrays[input_buffers.size() + output_index]; + + const HloComputation* fused_computation = + *GetOrCreateSubComputationFromRegion(&fusion.getRegion(), + /*is_fusion=*/true); + + FusedIrEmitter fused_emitter(elemental_emitter_); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + fused_emitter.BindGenerator( + *fused_computation->parameter_instruction(i), + [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { + return ir_arrays[i].EmitReadArrayElement(index, &b_); + }); + } + HloInstruction* instr = fused_computation->root_instruction(); + if (instr->opcode() != HloOpcode::kTuple) { + CHECK_EQ(0, output_index); + } else { + instr = instr->mutable_operand(output_index); + } + TF_RET_CHECK(instr->shape().IsArray()); + TF_ASSIGN_OR_RETURN(auto generator, + fused_emitter.GetGenerator(*instr->operand(1))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(generator, {dest_array}, launch_dimensions, &b_) + .EmitLoop(GetIrNameFromLoc(fusion.getLoc()))); + return OkStatus(); +} + StatusOr> IrEmitterUnnested::BuildWhileThunk( mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) { // Generate thunk sequence for while 'condition'. @@ -3293,6 +3443,263 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } +// Gets the output offset as calculated from thread_id.x (to be applied to the +// offset calculated from block_id and thread_id.y). +static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, + llvm::Value* thread_id_x, + llvm::Type* index_ty, + llvm::IRBuilder<>* b) { + int64_t multiplier = tiling_scheme.GetIndexingOrder() == kStridedIndexingX + ? tiling_scheme.GetVectorSize() + : tiling_scheme.GetTileSizeFor(kDimX); + return b->CreateMul(thread_id_x, + llvm::ConstantInt::get(index_ty, multiplier)); +} + +static IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, + absl::Span dims_in_elems) { + CHECK_EQ(normalized_shape_index.size(), 3); + // If the normalization only add a new dimensions of size 1, + // generate simpler indexing. LLVM doesn't always simplify the more + // complicated indexing and this prevents it from vectorizing some + // cases. We do this only for major_to_minor memory layout. 
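// For example, a normalized index [1, d0, d1] over an unnormalized
// f32[d0, d1] shape is handled by the two fast paths below, which reuse
// multidim[1] and multidim[2] directly (swapped for the transposed layout)
// instead of falling back to SourceIndexOfBitcast, whose more general index
// arithmetic LLVM may not be able to simplify.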
+ if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && + unnormalized_shape.layout().minor_to_major(1) == 0) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, + normalized_shape_index.GetType()); + } + if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && + unnormalized_shape.layout().minor_to_major(1) == 1) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape, + normalized_shape_index.GetType()); + } + return normalized_shape_index.SourceIndexOfBitcast( + ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, b_); +} + +static int GetNumOutputs(const Shape& shape) { + if (shape.IsTuple()) { + return shape.tuple_shapes_size(); + } + return 1; +} + +ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState( + mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info, + absl::Span reduce_instr_index_group, + FusedIrEmitter& fused_emitter) { + ReductionCodegenState reduction_codegen_state(reduction_info); + VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); + + for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { + int num_partial_results = reduction_codegen_state.GetNumPartialResults(); + for (int op_result_idx = 0; + op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { + Shape result_shape = reduce_hlo->shape().IsTuple() + ? reduce_hlo->shape().tuple_shapes(op_result_idx) + : reduce_hlo->shape(); + + llvm::Type* element_type = + llvm_ir::PrimitiveTypeToIrType(result_shape.element_type(), module_); + llvm::AllocaInst* reduction_input_address = + llvm_ir::EmitAllocaAtFunctionEntry(element_type, + "reduction_input_address", &b_); + + llvm::AllocaInst* partial_result_address = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + element_type, /*element_count=*/b_.getInt32(num_partial_results), + "partial_reduction_result", &b_); + + const HloInstruction* init_value = + reduce_hlo->init_values()[op_result_idx]; + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(*init_value))( + IrArray::Index(b_.getInt32Ty())) + .value(); + + for (int i = 0; i < num_partial_results; ++i) { + b_.CreateStore(init_ir_value, + InBoundsGEP(partial_result_address->getAllocatedType(), + partial_result_address, {b_.getInt32(i)})); + } + + const TilingScheme& tiling_scheme = + reduction_codegen_state.GetTilingScheme(); + int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX); + llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { + if (reduction_codegen_state.IsRowReduction()) { + // Multi-row reductions do not use shared memory. + if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > + 1) { + return nullptr; + } + // Allocate __shared__ + // cache[num_partial_results][num_warps][scaling_factor]. 
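// For example, with 256 threads per block and a 32-wide warp this is an
// 8-entry array per partial result (scaled once more by the tiling scheme's
// thread-id scaling factor inside AllocateShared); lane 0 of each warp
// writes its partial result here so that warp 0 can finish the reduction
// after the barrier.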
+ CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); + int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); + return AllocateShared(tiling_scheme, element_type, + {num_partial_results, num_warps}, + "shared_cache"); + } else { + // Allocate __shared__ + // cache[num_threads][num_threads + 1], where + // num_threads == num_threads_x == num_threads_y. The "+1" is used to + // avoid bank conflicts. + // + // (Although each thread produces num_partial_results results, we + // don't need that much cache: Only one result is live at a time.) + CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY)); + return AllocateShared(tiling_scheme, element_type, + {num_threads_x, num_threads_x + 1}, + "shared_cache"); + } + }(); + + llvm_ir::ElementGenerator input_gen = + *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); + reduction_codegen_state.SetCalculationStateFor( + {shared_cache, init_ir_value, partial_result_address, + reduction_input_address, input_gen}, + reduce_hlo, op_result_idx); + } + } + + return reduction_codegen_state; +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp) { + // This only works when the block size is a multiple of 32 threads. + + // We check this here as a mistake in the number of threads per + // block is very hard to detect. + CHECK_EQ(threads_per_block % 32, 0); + CHECK_EQ(WarpSize() % num_results_per_warp, 0); + + for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { + absl::InlinedVector reduction_params; + + for (auto acc : partial_result_addresses) { + reduction_params.push_back(acc.first); + } + + for (auto [partial_result_address, element_type] : + partial_result_addresses) { + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "result_from_other_lane", &b_); + + reduction_params.push_back(result_from_other_lane); + + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return b_.CreatePointerBitCastOrAddrSpaceCast( + ptr, shuffled_value_type->getPointerTo()); + }; + + llvm::Value* partial_result = + b_.CreateLoad(shuffled_value_type, + convert_pointer_for_shuffle(partial_result_address), + "partial_reduction_result"); + b_.CreateStore( + EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + } + + StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, + *reducer, reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + b_.CreateStore(/*Val=*/returned_scalars->at(i), + /*Ptr=*/partial_result_addresses[i].first); + } + } +} + +llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( + int partial_result_idx, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const IrEmitterUnnested::ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int output_idx) { + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; + + IrArray::Index start_offset = [&] { + llvm::Value* x_loc = thread_id_info.thread_id_x; + llvm::Value* y_loc = thread_id_info.thread_id_y; + if (!reduction_codegen_state.IsRowReduction()) { + std::swap(x_loc, y_loc); + } + llvm::Value* start_offset_x = + GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_); + return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_) + .AddOffsetToDim(start_offset_x, kDimX, &b_); + }(); + + const IrArray& output_array = output_arrays.at(reduction)[output_idx]; + const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); + Shape reduction_kept_element_shape = + ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); + + // Given the IrArray index of a reduction input, returns the linear address of + // the reduction output as if the reduction were going to keep the input shape + // with the dimensions being reduced moved. + llvm::Value* untransposed_output_linear_address = [&] { + const llvm_ir::IrArray::Index index = + start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, &b_); + if (reduction_codegen_state.IsRowReduction()) { + // For row-reduction, y-coordinate determines which row we write into. + return index[kDimY]; + } + // For column reduction, we get the transposed address. + absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); + llvm::Value* x_dim_size = + index.GetConstantWithIndexType(dims_in_elem[kDimX]); + llvm::Value* x_block_offset = b_.CreateMul(index[kDimZ], x_dim_size); + return b_.CreateAdd(x_block_offset, index[kDimX]); + }(); + + // A reduction is allowed to transpose its output. For example, suppose + // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are + // allowed to produce as output either f32[10,30]{1,0} (no transpose) or + // f32[10,30]{0,1} (transposing the two output dims). + // + // At this point in the function we have a "partial sum" of input elements + // (stored in partial_result_addresses), and we need to accumulate it into + // the correct output element. 
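// Concretely, the linear address computed above is first delinearized
// against reduction_kept_element_shape (the operand shape with the reduced
// dimensions removed) and then re-wrapped in the actual output shape and
// layout, so a transposed output layout is handled by this index conversion
// rather than by extra address arithmetic.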
+ IrArray::Index element_index( + /*linear=*/untransposed_output_linear_address, + reduction_kept_element_shape, &b_); + IrArray::Index output_index(element_index.multidim(), output_array.GetShape(), + element_index.GetType()); + + return output_array.EmitArrayElementAddress(output_index, &b_, + "output_element_address"); +} + llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type, llvm::Twine name) { @@ -3302,6 +3709,225 @@ llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, name); } +void IrEmitterUnnested::WriteReductionOutput( + llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx, + const absl::Span values) { + const HloComputation* reducer = reduction->to_apply(); + for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { + auto [output_ptr, type] = typed_ptr; + llvm::Value* output_address = GetOutputAddressForReduction( + partial_result_idx, index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, oidx); + if (reduction_codegen_state.IsRaceFree()) { + b_.CreateStore(b_.CreateLoad(type, output_ptr, "output"), output_address); + } else { + CHECK_EQ(values.size(), 1); + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + &b_, *ir_emitter_context_, *reducer, output_address, output_ptr, + type)); + } + } +} + +void IrEmitterUnnested::EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return b_.CreateICmpEQ(value, constant(0)); + }; + + int num_outputs = reducer->num_parameters() / 2; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + absl::InlinedVector current_outputs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + current_outputs.push_back( + {InBoundsGEP(state.partial_result_address->getAllocatedType(), + state.partial_result_address, + {constant(partial_result_idx)}, "current_output"), + state.partial_result_address->getAllocatedType()}); + } + + int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; + int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(current_outputs), + tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + + KernelSupportLibrary ksl(&b_); + llvm::Value* warp_id = + b_.CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + + auto emit_write_output = [&](llvm::Value* write_condition, + const absl::Span values) { + ksl.If("reduction_write_output", write_condition, [&] { + WriteReductionOutput(index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, + partial_result_idx, values); + }); + }; + + if (num_rows_per_warp > 1) { + llvm::Value* is_writing_thread = is_zero(b_.CreateAnd( + 
thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); + emit_write_output(is_writing_thread, current_outputs); + return; + } + + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, {constant(partial_result_idx), warp_id}); + Store(Load(current_outputs[oidx].second, current_outputs[oidx].first), + shmem_output_addr); + } + }); + + // TODO(cheshire): Don't we want to sync it once for everything in the + // output? Not once per each? + EmitSyncThreads(); + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { + absl::InlinedVector selected_values; + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, + {constant(partial_result_idx), thread_id_info.lane_id}); + + llvm::Type* element_type = + state.partial_result_address->getAllocatedType(); + + /* Insure initial value address is in generic, not scratch. */ + llvm::Value* initial_value_addr = + CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "initial_value_addr", &b_), + element_type); + b_.CreateStore(state.initial_value, initial_value_addr); + + llvm::Value* warp_exists = b_.CreateICmpULT( + thread_id_info.thread_id_x, + constant(tiling_scheme.GetNumThreadsFor(kDimX) / WarpSize())); + + llvm::Value* selected_value = + b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr); + + selected_values.push_back({selected_value, element_type}); + } + + // If only one warp is present in the block, then we don't need inter-warp + // reduction. + // TODO(b/241414088) If only warp is present, then inter-warp communication + // using shared memory and synchronization using barrier is also unnecessary + // and should be removed. + if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(selected_values), + tiling_scheme.GetNumThreadsPerBlock()); + } + + emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); + }); +} + +void IrEmitterUnnested::EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + KernelSupportLibrary ksl(&b_); + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return b_.CreateICmpEQ(value, constant(0)); + }; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + int num_outputs = reducer->num_parameters() / 2; + + // Wait for reads from shmem in the last iteration to complete. (If this is + // slow, we could "double-buffer" by having two shmem buffers and switching + // between them.) + if (partial_result_idx > 0) { + EmitSyncThreads(); + } + + // Store the transpose in shared memory. 
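// Each thread writes its partial result to cache[thread_id_x][thread_id_y]
// and, after the barrier below, reads back cache[thread_id_y][thread_id_x].
// The transpose puts values that belong to the same output column into
// consecutive lanes of a warp, so the same warp-shuffle reduction used for
// row reductions can combine them; the "+1" padding column in the cache
// allocation avoids shared-memory bank conflicts for these strided accesses.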
+ for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::GlobalVariable* shared_cache = state.shared_cache; + llvm::AddrSpaceCastInst* shmem_output_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + &b_, shared_cache, + {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, + "shmem_output_address")); + llvm::Value* current_output = + InBoundsGEP(state.partial_result_address->getAllocatedType(), + state.partial_result_address, + {constant(partial_result_idx)}, "current_output"); + + llvm::Value* current_output_value = + Load(state.partial_result_address->getAllocatedType(), current_output); + b_.CreateStore(current_output_value, shmem_output_addr); + } + + EmitSyncThreads(); + + // Get transposed element from shared memory. + absl::InlinedVector shmem_transposed_addrs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::AddrSpaceCastInst* shmem_transposed_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, + {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, + "shmem_transposed_addr")); + shmem_transposed_addrs.push_back( + {shmem_transposed_addr, llvm::cast( + shmem_transposed_addr->getPointerOperand()) + ->getResultElementType()}); + } + + EmitFullWarpShuffleDownLoopForReduce(reducer, + absl::MakeSpan(shmem_transposed_addrs), + tiling_scheme.GetNumThreadsPerBlock()); + + // Some warps in the block are completely outside of the bound of the + // tensor, so they should not write any output at all. + llvm::Value* has_output = + b_.CreateAnd(b_.CreateICmpULT(GetStartOffsetX(tiling_scheme, + thread_id_info.thread_id_y, + index_ty, &b_), + tiling_kernel_info.output_tile_bounds[1]), + b_.CreateICmpULT(thread_id_info.thread_id_x, + tiling_kernel_info.output_tile_bounds[0])); + + ksl.If("reduction_write_output", + b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + WriteReductionOutput(index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, + partial_result_idx, shmem_transposed_addrs); + }); +} + llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering::SequentiallyConsistent, "workgroup"); @@ -3481,6 +4107,239 @@ llvm::GlobalVariable* IrEmitterUnnested::AllocateShared( array_type, buffer_name); } +// Generate a single element of the tile (update the accumulator state) for a +// given reducer of index `i`. 
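// For each output of a (possibly variadic) reducer this loads the element
// produced by the corresponding fused input generator, stores it next to the
// running accumulator for the current partial result, and then invokes the
// reducer computation on the accumulator/input pairs.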
+void IrEmitterUnnested::GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays) { + HloComputation* reducer = reduction->to_apply(); + CHECK_EQ(reducer->num_parameters() % 2, 0); + + absl::InlinedVector reduction_accumulators; + absl::InlinedVector reduction_input_value; + for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + codegen_state.GetCalculationStateFor(reduction, red_idx); + + llvm::AllocaInst* input_address = state.input_address; + llvm::AllocaInst* partial_reduction_result_address = + state.partial_result_address; + llvm::Value* const input_ir_value = *state.input_gen( + num_partial_results > 1 ? index_without_linear : input_index); + b_.CreateStore(input_ir_value, input_address); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_address->getAllocatedType(), + partial_reduction_result_address, {partial_result_index}); + reduction_accumulators.push_back(partial_result_address); + reduction_input_value.push_back(input_address); + } + + absl::InlinedVector reduction_params; + for (llvm::Value* acc : reduction_accumulators) { + reduction_params.push_back(acc); + } + for (llvm::Value* value : reduction_input_value) { + reduction_params.push_back(value); + } + + // Emit a call to the variadic reducer. Since it may be returning a + // tuple, we can't return it directly as a value. Instead, before + // the call, we create N (N = # arguments in the tuple) allocas, one + // for each returned argument, then when we make the call we pass N + // pointers as last parameters, the called computation writes into + // those pointers, and we have returned values on the stack (as well + // as pointers to them). 
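// For example, a variadic min-and-argmin reducer with HLO signature
//   (f32, s32, f32, s32) -> (f32, s32)
// is invoked with six pointers: the two accumulator slots and the two freshly
// loaded input values (in the order reduction_params was assembled above),
// followed by two out-pointers that receive the reduced (value, index) pair.
// The concrete reducer here is only an illustration of the convention.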
+ StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, *reducer, + reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]); + } +} + +Status IrEmitterUnnested::EmitIRForReduction( + mlir::lmhlo::FusionOp fusion, + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenInfo& reduction_info, const Shape& input_shape) { + std::vector reductions; + ExtraOutputGensMap extra_output_gens; + + for (const HloInstruction* hlo : instr_index_group) { + if (IsReductionFromOrToContiguousDimensions(*hlo)) { + reductions.push_back(Cast(hlo)); + } else { + extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); + } + } + + CHECK(!reductions.empty()) << " expect at least one reduce instructions."; + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); + llvm::Type* index_ty = + GetIndexTypeForKernel(fusion, + tiling_scheme.GetNumThreadsPerBlockPhysical() * + tiling_scheme.GetNumberOfBlocksPhysical(), + &b_); + ReductionCodegenState codegen_state = GenerateReductionCodegenState( + fusion, reduction_info, reductions, fused_emitter); + + EmitTileElementFunction emit_reduction_element = + [&](const TilingThreadIdInfo& thread_id_info, const IrArray::Index& index, + llvm::Value* y_loc, llvm::Value* x_loc) { + IrArray::Index input_index = GetUnnormalizedIndex( + index, input_shape, &b_, + codegen_state.GetTilingScheme().GetDimsInElems()); + llvm::Value* partial_result_index = + codegen_state.IsRowReduction() + ? b_.getInt32(0) + : b_.CreateSub( + x_loc, + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, + index_ty, &b_)); + + // Clear the linear index field of the IrArray::Index to enable the use + // of GetElementPointer with array types. This enables the vectorization + // of the computation for different partial results. Use this index if + // 'num_partial_results > 1'. + int num_partial_results = codegen_state.GetNumPartialResults(); + IrArray::Index index_without_linear{input_index.multidim(), input_shape, + input_index.GetType()}; + + // Emit code to generate the input and perform the reduction computation + // for each reduction instruction. + for (const HloReduceInstruction* reduce : reductions) { + GenerateElementForReducer(reduce, partial_result_index, codegen_state, + index_without_linear, input_index, + num_partial_results, result_ir_arrays); + } + + // Emit code to generate the output for the non-reduction instructions + // in the fusion, if any. 
+ TF_CHECK_OK(EmitExtraOutputsForReduce(input_shape, result_ir_arrays, + input_index, reduction_info, + extra_output_gens)); + }; + + TF_ASSIGN_OR_RETURN( + TilingKernelInfo tiling_kernel_info, + EmitTilingKernel( + &b_, tiling_scheme, index_ty, + [&](const TilingThreadIdInfo& thread_id_info, + const IrArray::Index& index, ValueVector2 tile_dimensions) { + EmitTile(&b_, codegen_state.GetTilingScheme(), index, + thread_id_info, tile_dimensions, emit_reduction_element); + })); + + KernelSupportLibrary ksl(&b_); + for (const HloReduceInstruction* reduce : reductions) { + for (int partial_result_idx = 0; + partial_result_idx < reduction_info.GetNumPartialResults(); + ++partial_result_idx) { + if (codegen_state.IsRowReduction()) { + EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, reduce, + partial_result_idx); + } else { + EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, + reduce, partial_result_idx); + } + } + } + + return OkStatus(); +} + +Status IrEmitterUnnested::EmitUnnestedReduction( + mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { + auto* reduction_codegen_info = fusion_analysis.GetReductionCodegenInfo(); + // Set flag to false as Reduction has it's own custom logic of choosing a + // block size. + TF_ASSIGN_OR_RETURN(auto launch_dimensions, + fusion_analysis.GetLaunchDimensions( + /*use_experimental_block_size=*/false)); + + VLOG(3) << "Launch dimensions of " + << mlir::mhlo::GetDebugNameFromLocation(fusion.getLoc()) << ": " + << launch_dimensions.ToString(); + if (!reduction_codegen_info->IsRaceFree()) { + absl::Span fusion_roots = + fusion_analysis.fusion_roots(); + for (int i = 0; i < fusion_roots.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { + TF_RETURN_IF_ERROR(BuildFusedInitializerThunk(fusion, i)); + } + } + } + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion(fusion, launch_dimensions)); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. + return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + FusedIrEmitter fused_emitter(elemental_emitter_); + const HloComputation* fused_computation = fusion_analysis.fused_computation(); + CHECK_LT(fused_computation->num_parameters(), ir_arrays.size()); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + llvm_ir::IrArray ir_array = ir_arrays[i]; + HloInstruction* fused_operand = fused_computation->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, + [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement(index, &b_, + fused_operand->name()); + }); + } + + // Get outputs. + ReductionOutputMap result_ir_arrays; + + // Skip all parameter buffers first. + int ir_arrays_idx = fused_computation->num_parameters(); + for (HloInstruction* root : fusion_analysis.fusion_roots()) { + int get_num_results = GetNumOutputs(root->shape()); + result_ir_arrays[root] = + absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, get_num_results); + ir_arrays_idx += get_num_results; + } + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. 
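// The generated kernel body therefore has the rough shape
//   if (block_id_y == 0) { <IR for reduction group 0> }
//   if (block_id_y == 1) { <IR for reduction group 1> }
//   ...
// with one independent group of fusion roots per value of block_id_y, and
// range metadata on block_id_y so LLVM knows it stays below the number of
// groups.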
+ const std::vector>& instr_index_groups = + reduction_codegen_info->GetIndexGroups(); + Shape reduce_operand_shape = reduction_codegen_info->GetReduceOperandShape(); + + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), + llvm::cast(raw_block_id_y)); + for (int i = 0; i < instr_index_groups.size(); ++i) { + TF_RETURN_IF_ERROR(ksl.IfWithStatus( + StrCat("reduce-group-", i), + b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] { + return EmitIRForReduction( + fusion, instr_index_groups[i], fused_emitter, result_ir_arrays, + *reduction_codegen_info, reduce_operand_shape); + })); + } + + return OkStatus(); +} + // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 905ed163d64ead..a55276d9dcdb4f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -324,6 +324,75 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenInfo& reduction_info, const ExtraOutputGensMap& extra_output_gens); + // Generates code for reduction to contiguous dimensions. + // + // Row reduction uses the following algorithm described in CUDA-like + // pseudocode: + // + // ``` + // __global__ void reduce(int num_rows, float *in, float out) { + // __shared__ float[32] cache; + // int offset = blockDim.x * blockIdx.x + threadIdx.x; + // if (offset >= num_rows) return; + // int tile_bound = std::min(offset + kTileSizeX, num_rows); + // float accum = 0; + // for (int i=offset; i reduce_instr_index_group, + FusedIrEmitter& fused_emitter); + + // Wraps up the code generation for a tile block of a reduction kernel: + // write the calculated output into the output tensor. + void EmitReductionOutput( + llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion, + absl::Span reduce_instr_index_group, + const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info); + + // Returns the address to write the reduction output to. + llvm::Value* GetOutputAddressForReduction( + int partial_result_idx, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int output_idx); + + // Performs the actual write of the reduction result. + using TypedPointer = std::pair; + void WriteReductionOutput( + llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx, + const absl::Span values); + + // `current_output`: the value the tile has calculated. + // `output_address`: address where the output value has to be written. + void EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + llvm::Type* index_ty, const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx); + + // Same arguments as EmitReductionOutputForRowReduction. 
+ void EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + llvm::Type* index_ty, const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx); + + // Emits code for reductions in the output_instructions. + Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion, + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, + const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenInfo& reduction_info, + const Shape& input_shape); + + // Generate a single element of the tile (update the accumulator state) for a + // given reducer of index `i`. + void GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const llvm_ir::IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays); + + // Emits shuffle-down reduction for the `partial_result_address` using the + // reduction computation `reducer`, writes output into + // `partial_result_address`. + // + // Multiple partial_result_address inputs happen when doing variadic + // reduction: each one should get the output value. + void EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp = 1); + // Allocates a shared tile of given dimensions, applying scaling specified in // tilng_scheme as a major-most dimension to avoid collisions. llvm::GlobalVariable* AllocateShared( @@ -472,8 +618,20 @@ class IrEmitterUnnested : public IrEmitter { mlir::Operation* op, mlir::ValueRange needed_operands, const LaunchDimensions& launch_dimensions); + // Returns a thunk that, given a reduce or select-and-scatter op, + // initializes its memory to the appropriate initial value. + std::unique_ptr BuildConstantInitializerThunk( + mlir::Operation* op, absl::Span init_value, + mlir::Value dest, const BufferAllocation::Slice& dest_slice, + const Shape& output_shape); + + StatusOr> TryBuildConstantInitializerThunk( + mlir::Operation* op, mlir::Value init_value, mlir::Value dest); + Status BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest); + Status BuildFusedInitializerThunk(mlir::lmhlo::FusionOp fusion, + int output_index); // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. 
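For reference, below is a hedged, self-contained CUDA sketch of the row-reduction strategy described in the header comment above: per-thread strided accumulation, a warp-shuffle reduction within each warp, then a shared-memory hand-off so warp 0 can combine the per-warp results. The kernel name, launch shape, and float element type are illustrative assumptions, not the emitter's actual output.

```
// Illustrative sketch only; assumes blockDim.x is a multiple of 32 and at
// most 1024, and that `in` is a row-major [num_rows, row_length] array.
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__inline__ __device__ float WarpReduceSum(float v) {
  // Reduce across the 32 lanes of a warp with shuffle-down steps.
  for (int d = kWarpSize / 2; d >= 1; d /= 2)
    v += __shfl_down_sync(0xffffffffu, v, d);
  return v;
}

// One thread block reduces one row; launch as, e.g.,
//   ReduceRows<<<num_rows, 256>>>(in, out, row_length);
__global__ void ReduceRows(const float* in, float* out, int row_length) {
  __shared__ float cache[kWarpSize];  // one slot per warp (<= 32 warps/block)
  const float* row = in + static_cast<long long>(blockIdx.x) * row_length;

  // Phase 1: each thread accumulates a strided slice of its row.
  float accum = 0.0f;
  for (int i = threadIdx.x; i < row_length; i += blockDim.x) accum += row[i];

  // Phase 2: intra-warp reduction; lane 0 of each warp holds the warp's sum.
  accum = WarpReduceSum(accum);
  const int lane = threadIdx.x % kWarpSize;
  const int warp = threadIdx.x / kWarpSize;
  if (lane == 0) cache[warp] = accum;
  __syncthreads();

  // Phase 3: warp 0 combines the per-warp partial sums and writes the output.
  if (warp == 0) {
    const int num_warps = blockDim.x / kWarpSize;
    float block_accum = lane < num_warps ? cache[lane] : 0.0f;
    block_accum = WarpReduceSum(block_accum);
    if (lane == 0) out[blockIdx.x] = block_accum;
  }
}
```

The emitter generalizes this shape to arbitrary element types, variadic reducers, multiple rows per warp, and atomic accumulation when the reduction is not race free, but the three-phase structure is the same.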
From 943e7f07294f42d900e697bd4a77c783dfa10676 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 25 Jul 2023 01:25:09 -0700 Subject: [PATCH 095/410] [xla:gpu] Disable multi-threading in LLJIT PiperOrigin-RevId: 550806039 --- tensorflow/compiler/xla/runtime/execution_engine.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/runtime/execution_engine.cc b/tensorflow/compiler/xla/runtime/execution_engine.cc index 56915cfc3d4654..1de98620562de4 100644 --- a/tensorflow/compiler/xla/runtime/execution_engine.cc +++ b/tensorflow/compiler/xla/runtime/execution_engine.cc @@ -318,6 +318,7 @@ ExecutionEngine::CreateFromModule(std::unique_ptr ctx, .setCompileFunctionCreator(compile_function_creator) .setObjectLinkingLayerCreator(obj_layer_creator) .setExecutorProcessControl(std::move(*executorProcessControl)) + .setNumCompileThreads(0) // disable multi-threading .create(); if (auto err = jit.takeError()) From 1dec98e327ea8d1a6d07a678888d90d3aaca1745 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 02:02:36 -0700 Subject: [PATCH 096/410] Update GraphDef version to 1568. PiperOrigin-RevId: 550813708 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index ea82e47c14ad5e..046d4cc8289420 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1567 // Updated: 2023/7/24 +#define TF_GRAPH_DEF_VERSION 1568 // Updated: 2023/7/25 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From cbf699b94ca26baf24165239a9feba8db06205e4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 02:02:39 -0700 Subject: [PATCH 097/410] compat: Update forward compatibility horizon to 2023-07-25 PiperOrigin-RevId: 550813723 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index ebb0874e051e4e..a0b5b73528629f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 25) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 80bbe8dd2b0248b895484252ee63a9c6a5c50984 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jul 2023 02:19:01 -0700 Subject: [PATCH 098/410] Extract reduction fusion code from ir_emitter_unnested. 
PiperOrigin-RevId: 550818274 --- tensorflow/compiler/xla/service/gpu/BUILD | 4 + .../compiler/xla/service/gpu/fusions/BUILD | 46 + .../xla/service/gpu/fusions/fusion_emitter.cc | 4 +- .../xla/service/gpu/fusions/fusion_emitter.h | 9 + .../xla/service/gpu/fusions/fusions.cc | 4 + .../xla/service/gpu/fusions/reduction.cc | 959 ++++++++++++++++++ .../xla/service/gpu/fusions/reduction.h | 120 +++ .../xla/service/gpu/fusions/thunk_util.cc | 120 +++ .../xla/service/gpu/fusions/thunk_util.h | 37 + .../xla/service/gpu/fusions/tiling_util.cc | 33 + .../xla/service/gpu/fusions/tiling_util.h | 5 + .../compiler/xla/service/gpu/ir_emitter.cc | 4 + .../xla/service/gpu/ir_emitter_unnested.cc | 871 +--------------- .../xla/service/gpu/ir_emitter_unnested.h | 158 --- 14 files changed, 1349 insertions(+), 1025 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.cc create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.h create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ade65c0b8a46..4bf5b9d4159a14 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -309,6 +309,7 @@ cc_library( ":fft_thunk", ":gemm_thunk", ":gpu_asm_opts_util", + ":gpu_constants", ":gpu_conv_runner", ":gpu_device_info", ":gpu_executable", @@ -341,8 +342,10 @@ cc_library( "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/gpu/fusions", + "//tensorflow/compiler/xla/service/gpu/fusions:thunk_util", "//tensorflow/compiler/xla/service/gpu/fusions:tiling_util", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -406,6 +409,7 @@ cc_library( ":backend_configs_cc", ":hlo_fusion_analysis", ":hlo_to_ir_bindings", + ":ir_emission_utils", ":ir_emitter_context", ":kernel_reuse_cache", ":target_util", diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 0673e218c07396..644544246200ce 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -61,6 +61,7 @@ cc_library( ":fusion_emitter", ":in_place_dynamic_update_slice", ":loop", + ":reduction", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", @@ -101,3 +102,48 @@ cc_library( "@llvm-project//llvm:ir_headers", ], ) + +cc_library( + name = "reduction", + srcs = ["reduction.cc"], + hdrs = ["reduction.h"], + deps = [ + ":fusion_emitter", + ":thunk_util", + ":tiling_util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/gpu:ir_emitter", + "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", + "//tensorflow/compiler/xla/service/gpu:kernel_reuse_cache", + 
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:target_util", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", + "@llvm-project//llvm:ir_headers", + ], +) + +cc_library( + name = "thunk_util", + srcs = ["thunk_util.cc"], + hdrs = ["thunk_util.h"], + visibility = ["//tensorflow/compiler/xla/service/gpu:__subpackages__"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", + "//tensorflow/compiler/xla/service/gpu:thunk", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc index b4c60d1c2bc334..ab839a522c5df0 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc @@ -74,6 +74,8 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, } } +} // namespace + std::tuple, std::vector> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, @@ -171,8 +173,6 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, return {kernel, std::move(inputs), std::move(outputs)}; } -} // namespace - StatusOr KernelFusionEmitterBase::Emit( KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { std::string suggested_kernel_name = GetIrNameFromLoc(fusion_op_->getLoc()); diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h index 0fdc738a29eee5..b31247bd987da6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h @@ -80,6 +80,15 @@ class KernelFusionEmitterBase : public FusionInterface { const HloFusionInstruction& fusion_; }; +std::tuple, + std::vector> +BuildKernelPrototype(IrEmitterContext& ir_emitter_context, + const std::string& suggested_name, + absl::Span arguments, + size_t num_inputs, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* builder); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index a6c800556c5c7a..34723c1ec64eef 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/fusions/copy.h" #include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -49,6 +50,9 @@ std::optional> GetFusionEmitter( ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion) { switch (analysis.GetEmitterFusionKind()) { + case HloFusionAnalysis::EmitterFusionKind::kReduction: + return std::make_unique( + ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { bool is_single = IsSingleInstructionFusion(fusion_op); if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc b/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc new file mode 100644 index 00000000000000..dfc1fbaa3f94d2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc @@ -0,0 +1,959 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" + +#include +#include + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_reuse_cache.h" +#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" + +namespace xla { +namespace gpu { +namespace { + +using TypedPointer = std::pair; + +// Fusion root -> array of indexes, one per reduction output. 
+using ReductionOutputMap = + ConstHloInstructionMap>; +using ExtraOutputGensMap = ConstHloInstructionMap; + +void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context) { + auto* module = builder->GetInsertBlock()->getModule(); + if (IsAMDGPU(module) && + ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( + 0, 6) == "gfx90a") { + builder->CreateFence( + llvm::AtomicOrdering::SequentiallyConsistent, + builder->getContext().getOrInsertSyncScopeID("workgroup")); + } +} + +void EmitSyncThreads(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context) { + MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); +} + +// For a row reduction, returns the number of rows we can process in parallel +// per warp. +int RowReductionGetRowsPerWarp(int reduced_dimension_size) { + if (WarpSize() % reduced_dimension_size != 0 || + reduced_dimension_size >= WarpSize()) { + return 1; + } + return WarpSize() / reduced_dimension_size; +} + +llvm::GlobalVariable* AllocateShared( + llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, + llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name) { + CHECK(!dimensions_major_to_minor.empty()); + llvm::Type* ty = element_type; + for (auto dim : llvm::reverse(dimensions_major_to_minor)) { + ty = llvm::ArrayType::get(ty, dim); + } + ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); + return llvm_ir::AllocateSharedMemoryTile( + builder->GetInsertBlock()->getModule(), ty, buffer_name); +} + +Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, + const Shape& reduction_operand_shape, + const ReductionOutputMap& result_ir_arrays, + const llvm_ir::IrArray::Index& index, + const ReductionCodegenInfo& reduction_info, + const ExtraOutputGensMap& extra_output_gens) { + if (extra_output_gens.empty()) { + return OkStatus(); + } + + // Compute all extra output values before writing them. This avoids + // overwriting aliased input/output buffers before all reads occurred. + std::vector> + extra_output_ir_values; + extra_output_ir_values.reserve(extra_output_gens.size()); + + auto get_index = [&](const HloInstruction* instr) { + const Shape& s = instr->shape(); + return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) + ? 
index + : index.SourceIndexOfBitcast(reduction_operand_shape, s, + builder); + }; + + for (const auto& [instr, generator] : extra_output_gens) { + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + generator(get_index(instr))); + extra_output_ir_values.emplace_back(instr, extra_output_ir_value); + } + + for (const auto& [instr, generator] : extra_output_ir_values) { + absl::Span result_ir = result_ir_arrays.at(instr); + CHECK_EQ(result_ir.size(), 1); + result_ir[0].EmitWriteArrayElement( + get_index(instr), generator, builder, /*use_linear_index=*/ + reduction_info.GetNumPartialResults() == 1); + } + return OkStatus(); +} + +ReductionCodegenState GenerateReductionCodegenState( + llvm::IRBuilder<>* builder, mlir::lmhlo::FusionOp fusion, + const ReductionCodegenInfo& reduction_info, + absl::Span reduce_instr_index_group, + FusedIrEmitter& fused_emitter) { + ReductionCodegenState reduction_codegen_state(reduction_info); + VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); + + for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { + int num_partial_results = reduction_codegen_state.GetNumPartialResults(); + int num_outputs = reduce_hlo->shape().IsTuple() + ? reduce_hlo->shape().tuple_shapes_size() + : 1; + for (int op_result_idx = 0; op_result_idx < num_outputs; op_result_idx++) { + Shape result_shape = reduce_hlo->shape().IsTuple() + ? reduce_hlo->shape().tuple_shapes(op_result_idx) + : reduce_hlo->shape(); + + llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( + result_shape.element_type(), builder->GetInsertBlock()->getModule()); + llvm::AllocaInst* reduction_input_address = + llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "reduction_input_address", builder); + + llvm::AllocaInst* partial_result_address = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + element_type, + /*element_count=*/builder->getInt32(num_partial_results), + "partial_reduction_result", builder); + + const HloInstruction* init_value = + reduce_hlo->init_values()[op_result_idx]; + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( + *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) + .value(); + + for (int i = 0; i < num_partial_results; ++i) { + builder->CreateStore( + init_ir_value, builder->CreateInBoundsGEP( + partial_result_address->getAllocatedType(), + partial_result_address, {builder->getInt32(i)})); + } + + const TilingScheme& tiling_scheme = + reduction_codegen_state.GetTilingScheme(); + int64_t num_threads_x = + tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); + llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { + if (reduction_codegen_state.IsRowReduction()) { + // Multi-row reductions do not use shared memory. + if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > + 1) { + return nullptr; + } + // Allocate __shared__ + // cache[num_partial_results][num_warps][scaling_factor]. + CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); + int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); + return AllocateShared(builder, tiling_scheme, element_type, + {num_partial_results, num_warps}, + "shared_cache"); + } else { + // Allocate __shared__ + // cache[num_threads][num_threads + 1], where + // num_threads == num_threads_x == num_threads_y. The "+1" is used to + // avoid bank conflicts. 
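+          // With a square [num_threads][num_threads] tile whose side is a
+          // multiple of the warp size, the accesses in which thread_id.x
+          // indexes the major dimension would all land in the same
+          // shared-memory bank and serialize; padding the minor dimension to
+          // num_threads + 1 staggers them across banks.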
+ // + // (Although each thread produces num_partial_results results, we + // don't need that much cache: Only one result is live at a time.) + CHECK_EQ(num_threads_x, + tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); + return AllocateShared(builder, tiling_scheme, element_type, + {num_threads_x, num_threads_x + 1}, + "shared_cache"); + } + }(); + + llvm_ir::ElementGenerator input_gen = + *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); + reduction_codegen_state.SetCalculationStateFor( + {shared_cache, init_ir_value, partial_result_address, + reduction_input_address, input_gen}, + reduce_hlo, op_result_idx); + } + } + + return reduction_codegen_state; +} + +// Generate a single element of the tile (update the accumulator state) for a +// given reducer of index `i`. +void GenerateElementForReducer( + llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const llvm_ir::IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays) { + HloComputation* reducer = reduction->to_apply(); + CHECK_EQ(reducer->num_parameters() % 2, 0); + + absl::InlinedVector reduction_accumulators; + absl::InlinedVector reduction_input_value; + for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + codegen_state.GetCalculationStateFor(reduction, red_idx); + + llvm::AllocaInst* input_address = state.input_address; + llvm::AllocaInst* partial_reduction_result_address = + state.partial_result_address; + llvm::Value* const input_ir_value = *state.input_gen( + num_partial_results > 1 ? index_without_linear : input_index); + builder->CreateStore(input_ir_value, input_address); + llvm::Value* partial_result_address = builder->CreateInBoundsGEP( + partial_reduction_result_address->getAllocatedType(), + partial_reduction_result_address, {partial_result_index}); + reduction_accumulators.push_back(partial_result_address); + reduction_input_value.push_back(input_address); + } + + absl::InlinedVector reduction_params; + for (llvm::Value* acc : reduction_accumulators) { + reduction_params.push_back(acc); + } + for (llvm::Value* value : reduction_input_value) { + reduction_params.push_back(value); + } + + // Emit a call to the variadic reducer. Since it may be returning a + // tuple, we can't return it directly as a value. Instead, before + // the call, we create N (N = # arguments in the tuple) allocas, one + // for each returned argument, then when we make the call we pass N + // pointers as last parameters, the called computation writes into + // those pointers, and we have returned values on the stack (as well + // as pointers to them). + StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, + *reducer, reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); + } +} + +// Emits shuffle-down reduction for the `partial_result_address` using the +// reduction computation `reducer`, writes output into +// `partial_result_address`. +// +// Multiple partial_result_address inputs happen when doing variadic +// reduction: each one should get the output value. 
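+//
+// Roughly, for a single value and one result per warp, this emits the
+// CUDA-like pattern
+//
+//   for (int distance = 16; distance >= 1; distance /= 2)
+//     partial = reducer(partial,
+//                       __shfl_down_sync(0xffffffff, partial, distance));
+//
+// so that after log2(WarpSize()) steps lane 0 holds the reduction of the
+// whole warp.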
+void EmitFullWarpShuffleDownLoopForReduce( + llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp = 1) { + // This only works when the block size is a multiple of 32 threads. + + // We check this here as a mistake in the number of threads per + // block is very hard to detect. + CHECK_EQ(threads_per_block % 32, 0); + CHECK_EQ(WarpSize() % num_results_per_warp, 0); + + for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { + absl::InlinedVector reduction_params; + + for (auto acc : partial_result_addresses) { + reduction_params.push_back(acc.first); + } + + for (auto [partial_result_address, element_type] : + partial_result_addresses) { + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "result_from_other_lane", builder); + + reduction_params.push_back(result_from_other_lane); + + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = element_type->isStructTy() + ? builder->getIntNTy(bit_width) + : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return builder->CreatePointerBitCastOrAddrSpaceCast( + ptr, shuffled_value_type->getPointerTo()); + }; + + llvm::Value* partial_result = builder->CreateLoad( + shuffled_value_type, + convert_pointer_for_shuffle(partial_result_address), + "partial_reduction_result"); + builder->CreateStore( + EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), + builder), + convert_pointer_for_shuffle(result_from_other_lane)); + } + + StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, + *reducer, reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + builder->CreateStore(/*Val=*/returned_scalars->at(i), + /*Ptr=*/partial_result_addresses[i].first); + } + } +} + +// Gets the output offset as calculated from thread_id.x (to be applied to the +// offset calculated from block_id and thread_id.y). +llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, + llvm::Value* thread_id_x, llvm::Type* index_ty, + llvm::IRBuilder<>* b) { + int64_t multiplier = + tiling_scheme.GetIndexingOrder() == TilingScheme::StridedIndexingX + ? 
tiling_scheme.GetVectorSize() + : tiling_scheme.GetTileSizeFor(TilingScheme::DimX); + return b->CreateMul(thread_id_x, + llvm::ConstantInt::get(index_ty, multiplier)); +} + +llvm::Value* GetOutputAddressForReduction( + llvm::IRBuilder<>* builder, int partial_result_idx, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int output_idx) { + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; + + llvm_ir::IrArray::Index start_offset = [&] { + llvm::Value* x_loc = thread_id_info.thread_id_x; + llvm::Value* y_loc = thread_id_info.thread_id_y; + if (!reduction_codegen_state.IsRowReduction()) { + std::swap(x_loc, y_loc); + } + llvm::Value* start_offset_x = + GetStartOffsetX(tiling_scheme, x_loc, index_ty, builder); + return tiling_kernel_info.tile_origin + .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) + .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); + }(); + + const llvm_ir::IrArray& output_array = + output_arrays.at(reduction)[output_idx]; + const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); + Shape reduction_kept_element_shape = + ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); + + // Given the IrArray index of a reduction input, returns the linear address of + // the reduction output as if the reduction were going to keep the input shape + // with the dimensions being reduced moved. + llvm::Value* untransposed_output_linear_address = [&] { + const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( + constant(partial_result_idx), TilingScheme::DimX, builder); + if (reduction_codegen_state.IsRowReduction()) { + // For row-reduction, y-coordinate determines which row we write into. + return index[TilingScheme::DimY]; + } + // For column reduction, we get the transposed address. + absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); + llvm::Value* x_dim_size = + index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); + llvm::Value* x_block_offset = + builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); + return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); + }(); + + // A reduction is allowed to transpose its output. For example, suppose + // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are + // allowed to produce as output either f32[10,30]{1,0} (no transpose) or + // f32[10,30]{0,1} (transposing the two output dims). + // + // At this point in the function we have a "partial sum" of input elements + // (stored in partial_result_addresses), and we need to accumulate it into + // the correct output element. + llvm_ir::IrArray::Index element_index( + /*linear=*/untransposed_output_linear_address, + reduction_kept_element_shape, builder); + llvm_ir::IrArray::Index output_index(element_index.multidim(), + output_array.GetShape(), + element_index.GetType()); + + return output_array.EmitArrayElementAddress(output_index, builder, + "output_element_address"); +} + +// Wraps up the code generation for a tile block of a reduction kernel: +// write the calculated output into the output tensor. 
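+//
+// If the reduction is race free (no other block writes the same output
+// element), the result is written with a plain store; otherwise it is
+// combined into the output buffer by emitting an atomic version of the
+// reducer.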
+void WriteReductionOutput(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context, + llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, + int partial_result_idx, + const absl::Span values) { + const HloComputation* reducer = reduction->to_apply(); + for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { + auto [output_ptr, type] = typed_ptr; + llvm::Value* output_address = GetOutputAddressForReduction( + builder, partial_result_idx, index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, oidx); + if (reduction_codegen_state.IsRaceFree()) { + builder->CreateStore(builder->CreateLoad(type, output_ptr, "output"), + output_address); + } else { + CHECK_EQ(values.size(), 1); + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + builder, ir_emitter_context, *reducer, output_address, output_ptr, + type)); + } + } +} + +// `current_output`: the value the tile has calculated. +// `output_address`: address where the output value has to be written. +void EmitReductionOutputForRowReduction( + llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return builder->CreateICmpEQ(value, constant(0)); + }; + + int num_outputs = reducer->num_parameters() / 2; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + absl::InlinedVector current_outputs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + current_outputs.push_back( + {builder->CreateInBoundsGEP( + state.partial_result_address->getAllocatedType(), + state.partial_result_address, {constant(partial_result_idx)}, + "current_output"), + state.partial_result_address->getAllocatedType()}); + } + + int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; + int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); + EmitFullWarpShuffleDownLoopForReduce( + builder, ir_emitter_context, reducer, absl::MakeSpan(current_outputs), + tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + + KernelSupportLibrary ksl(builder); + llvm::Value* warp_id = + builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + + auto emit_write_output = [&](llvm::Value* write_condition, + const absl::Span values) { + ksl.If("reduction_write_output", write_condition, [&] { + WriteReductionOutput(builder, ir_emitter_context, index_ty, + reduction_codegen_state, tiling_kernel_info, + output_arrays, reduction, partial_result_idx, + values); + }); + }; + + if (num_rows_per_warp > 1) { + llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( + thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); + emit_write_output(is_writing_thread, current_outputs); + return; + } + + ksl.If("intra_warp_reduce_write", 
is_zero(thread_id_info.lane_id), [&] { + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( + builder, state.shared_cache, {constant(partial_result_idx), warp_id}); + builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, + current_outputs[oidx].first), + shmem_output_addr); + } + }); + + // TODO(cheshire): Don't we want to sync it once for everything in the + // output? Not once per each? + EmitSyncThreads(builder, ir_emitter_context); + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { + absl::InlinedVector selected_values; + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( + builder, state.shared_cache, + {constant(partial_result_idx), thread_id_info.lane_id}); + + llvm::Type* element_type = + state.partial_result_address->getAllocatedType(); + + // Ensure initial value address is in generic, not scratch. + llvm::Value* initial_value_addr = builder->CreateAddrSpaceCast( + llvm_ir::EmitAllocaAtFunctionEntry(element_type, "initial_value_addr", + builder), + llvm::PointerType::get(element_type, + /*AddressSpace=*/0)); + builder->CreateStore(state.initial_value, initial_value_addr); + + llvm::Value* warp_exists = builder->CreateICmpULT( + thread_id_info.thread_id_x, + constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / + WarpSize())); + + llvm::Value* selected_value = builder->CreateSelect( + warp_exists, block_accum_addr, initial_value_addr); + + selected_values.push_back({selected_value, element_type}); + } + + // If only one warp is present in the block, then we don't need inter-warp + // reduction. + // TODO(b/241414088) If only warp is present, then inter-warp communication + // using shared memory and synchronization using barrier is also unnecessary + // and should be removed. + if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { + EmitFullWarpShuffleDownLoopForReduce( + builder, ir_emitter_context, reducer, absl::MakeSpan(selected_values), + tiling_scheme.GetNumThreadsPerBlock()); + } + + emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); + }); +} + +// Same arguments as EmitReductionOutputForRowReduction. +void EmitReductionOutputForColumnReduction( + llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + KernelSupportLibrary ksl(builder); + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return builder->CreateICmpEQ(value, constant(0)); + }; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + int num_outputs = reducer->num_parameters() / 2; + + // Wait for reads from shmem in the last iteration to complete. (If this is + // slow, we could "double-buffer" by having two shmem buffers and switching + // between them.) 
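+  //
+  // Overall, each partial result is staged in shared memory at
+  // [thread_id.x][thread_id.y], read back transposed at
+  // [thread_id.y][thread_id.x], reduced across the warp with the shuffle-down
+  // loop, and finally written by lane 0 of the warps that fall inside the
+  // output tile bounds.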
+ if (partial_result_idx > 0) { + EmitSyncThreads(builder, ir_emitter_context); + } + + // Store the transpose in shared memory. + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::GlobalVariable* shared_cache = state.shared_cache; + llvm::AddrSpaceCastInst* shmem_output_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + builder, shared_cache, + {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, + "shmem_output_address")); + llvm::Value* current_output = builder->CreateInBoundsGEP( + state.partial_result_address->getAllocatedType(), + state.partial_result_address, {constant(partial_result_idx)}, + "current_output"); + + llvm::Value* current_output_value = builder->CreateLoad( + state.partial_result_address->getAllocatedType(), current_output); + builder->CreateStore(current_output_value, shmem_output_addr); + } + + EmitSyncThreads(builder, ir_emitter_context); + + // Get transposed element from shared memory. + absl::InlinedVector shmem_transposed_addrs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::AddrSpaceCastInst* shmem_transposed_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + builder, state.shared_cache, + {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, + "shmem_transposed_addr")); + shmem_transposed_addrs.push_back( + {shmem_transposed_addr, llvm::cast( + shmem_transposed_addr->getPointerOperand()) + ->getResultElementType()}); + } + + EmitFullWarpShuffleDownLoopForReduce(builder, ir_emitter_context, reducer, + absl::MakeSpan(shmem_transposed_addrs), + tiling_scheme.GetNumThreadsPerBlock()); + + // Some warps in the block are completely outside of the bound of the + // tensor, so they should not write any output at all. + llvm::Value* has_output = builder->CreateAnd( + builder->CreateICmpULT( + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty, + builder), + tiling_kernel_info.output_tile_bounds[1]), + builder->CreateICmpULT(thread_id_info.thread_id_x, + tiling_kernel_info.output_tile_bounds[0])); + + ksl.If("reduction_write_output", + builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + WriteReductionOutput(builder, ir_emitter_context, index_ty, + reduction_codegen_state, tiling_kernel_info, + output_arrays, reduction, partial_result_idx, + shmem_transposed_addrs); + }); +} + +// Emits code for reductions in the output_instructions. 
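+//
+// The roots in `instr_index_group` are split into reduction instructions and
+// "extra outputs" (non-reduction roots that reuse the reduction input index).
+// A tiled loop nest is emitted over the input; each thread accumulates its
+// partial results, which are then combined by the row- or column-reduction
+// epilogues above.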
+Status EmitIRForReduction(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context, + mlir::lmhlo::FusionOp fusion, + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, + const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenInfo& reduction_info, + const Shape& input_shape) { + std::vector reductions; + ExtraOutputGensMap extra_output_gens; + + for (const HloInstruction* hlo : instr_index_group) { + if (IsReductionFromOrToContiguousDimensions(*hlo)) { + reductions.push_back(Cast(hlo)); + } else { + extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); + } + } + + CHECK(!reductions.empty()) << " expect at least one reduce instructions."; + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); + llvm::Type* index_ty = + GetIndexTypeForKernel(fusion, + tiling_scheme.GetNumThreadsPerBlockPhysical() * + tiling_scheme.GetNumberOfBlocksPhysical(), + builder); + ReductionCodegenState codegen_state = GenerateReductionCodegenState( + builder, fusion, reduction_info, reductions, fused_emitter); + + EmitTileElementFunction emit_reduction_element = + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( + index, input_shape, builder, + codegen_state.GetTilingScheme().GetDimsInElems()); + llvm::Value* partial_result_index = + codegen_state.IsRowReduction() + ? builder->getInt32(0) + : builder->CreateSub( + x_loc, + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, + index_ty, builder)); + + // Clear the linear index field of the IrArray::Index to enable the use + // of GetElementPointer with array types. This enables the vectorization + // of the computation for different partial results. Use this index if + // 'num_partial_results > 1'. + int num_partial_results = codegen_state.GetNumPartialResults(); + llvm_ir::IrArray::Index index_without_linear{ + input_index.multidim(), input_shape, input_index.GetType()}; + + // Emit code to generate the input and perform the reduction computation + // for each reduction instruction. + for (const HloReduceInstruction* reduce : reductions) { + GenerateElementForReducer(builder, ir_emitter_context, reduce, + partial_result_index, codegen_state, + index_without_linear, input_index, + num_partial_results, result_ir_arrays); + } + + // Emit code to generate the output for the non-reduction instructions + // in the fusion, if any. 
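+        // (Extra outputs reuse the reduction input index, so they are
+        // produced inside this same tile loop rather than by a separate
+        // kernel.)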
+ TF_CHECK_OK(EmitExtraOutputsForReduce( + builder, input_shape, result_ir_arrays, input_index, reduction_info, + extra_output_gens)); + }; + + TF_ASSIGN_OR_RETURN( + TilingKernelInfo tiling_kernel_info, + EmitTilingKernel(builder, tiling_scheme, index_ty, + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& index, + std::array tile_dimensions) { + EmitTile(builder, codegen_state.GetTilingScheme(), + index, thread_id_info, tile_dimensions, + emit_reduction_element); + })); + + KernelSupportLibrary ksl(builder); + for (const HloReduceInstruction* reduce : reductions) { + for (int partial_result_idx = 0; + partial_result_idx < reduction_info.GetNumPartialResults(); + ++partial_result_idx) { + if (codegen_state.IsRowReduction()) { + EmitReductionOutputForRowReduction( + builder, ir_emitter_context, tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, reduce, partial_result_idx); + } else { + EmitReductionOutputForColumnReduction( + builder, ir_emitter_context, tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, reduce, partial_result_idx); + } + } + } + + return OkStatus(); +} + +StatusOr> BuildKernelThunkForFusion( + IrEmitterContext& ir_emitter_context, KernelReuseCache& kernel_cache, + mlir::lmhlo::FusionOp fusion_op, const HloInstruction& fusion, + const LaunchDimensions& launch_dimensions, absl::string_view discriminator, + std::function, + std::vector)> + kernel_builder_fn, + llvm::IRBuilder<>* builder) { + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context.allocations(), fusion_op)); + + Status kernel_builder_status = OkStatus(); + auto [entry, cached] = kernel_cache.Get( + fusion.fused_instructions_computation(), kernel_arguments.args(), + discriminator, [&]() -> KernelReuseCache::Entry { + std::vector inputs, outputs; + llvm::Function* kernel; + std::tie(kernel, inputs, outputs) = BuildKernelPrototype( + ir_emitter_context, GetIrNameFromLoc(fusion_op->getLoc()), + kernel_arguments.args(), fusion.fused_parameters().size(), + launch_dimensions, builder); + kernel_builder_status = + kernel_builder_fn(std::move(inputs), std::move(outputs)); + return {kernel->getName().str(), launch_dimensions}; + }); + TF_RETURN_IF_ERROR(kernel_builder_status); + + return std::make_unique( + fusion_op, entry.kernel_name, kernel_arguments.args(), launch_dimensions); +} + +StatusOr> BuildFusedInitializerThunk( + IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, + const HloInstruction& fusion, ElementalIrEmitter& elemental_emitter, + KernelReuseCache& kernel_cache, int output_index, + llvm::IRBuilder<>* builder) { + auto reduce = mlir::dyn_cast_or_null( + fusion_op.getFusionRoots()[output_index]); + + TF_RET_CHECK(reduce); + TF_RET_CHECK(reduce.getNumResults() == 1); + + mlir::Value init_value = reduce.getInitValues()[0]; + mlir::Value dest = fusion_op.getOutputBuffers()[output_index]; + TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, + BuildConstantInitializerThunk( + ir_emitter_context, fusion_op, init_value, dest)); + if (constant_init_thunk) { + return *std::move(constant_init_thunk); + } + + auto input_buffers = fusion_op.getInputBuffers(); + + const Shape dest_shape = GetShape(dest); + bool use_experimental_block_size = + ir_emitter_context.debug_options() + .xla_gpu_enable_experimental_block_size(); + + TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, + CalculateLaunchDimensions( + dest_shape, ir_emitter_context.gpu_device_info(), + use_experimental_block_size)); + + const 
HloComputation* fused_computation = + fusion.fused_instructions_computation(); + HloInstruction* instr = fused_computation->root_instruction(); + if (instr->opcode() != HloOpcode::kTuple) { + CHECK_EQ(0, output_index); + } else { + instr = instr->mutable_operand(output_index); + } + TF_RET_CHECK(instr->shape().IsArray()); + + auto kernel_builder = [&](std::vector inputs, + std::vector outputs) -> Status { + FusedIrEmitter fused_emitter(elemental_emitter); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + fused_emitter.BindGenerator( + *fused_computation->parameter_instruction(i), + [builder, &inputs, + i](llvm_ir::IrArray::Index index) -> StatusOr { + return inputs[i].EmitReadArrayElement(index, builder); + }); + } + TF_ASSIGN_OR_RETURN(auto generator, + fused_emitter.GetGenerator(*instr->operand(1))); + return ParallelLoopEmitter(generator, {outputs[0]}, launch_dimensions, + builder) + .EmitLoop(GetIrNameFromLoc(fusion_op.getLoc())); + }; + return BuildKernelThunkForFusion( + ir_emitter_context, kernel_cache, fusion_op, fusion, launch_dimensions, + /*discriminator=*/ + absl::StrCat("init_", output_index), kernel_builder, builder); +} + +} // namespace + +StatusOr ReductionFusion::Emit( + KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { + auto* reduction_codegen_info = analysis_.GetReductionCodegenInfo(); + // Set `use_experimental_block_size` flag to false as the reduction code + // has its own custom logic of choosing a block size. + TF_ASSIGN_OR_RETURN(auto launch_dimensions, + analysis_.GetLaunchDimensions( + /*use_experimental_block_size=*/false)); + + FusionEmissionResult result; + VLOG(3) << "Launch dimensions of " + << mlir::mhlo::GetDebugNameFromLocation(fusion_op().getLoc()) << ": " + << launch_dimensions.ToString(); + if (!reduction_codegen_info->IsRaceFree()) { + absl::Span fusion_roots = analysis_.fusion_roots(); + for (int i = 0; i < fusion_roots.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { + TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(), + BuildFusedInitializerThunk( + ir_emitter_context_, fusion_op(), fusion_, + elemental_emitter_, kernel_cache, i, builder)); + } + } + } + + auto kernel_builder = [&](std::vector inputs, + std::vector outputs) -> Status { + FusedIrEmitter fused_emitter(elemental_emitter_); + const HloComputation* fused_computation = analysis_.fused_computation(); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + llvm_ir::IrArray ir_array = inputs[i]; + HloInstruction* fused_operand = + fused_computation->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, + [builder, ir_array, fused_operand]( + const llvm_ir::IrArray::Index& index) -> StatusOr { + return ir_array.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); + } + + // Get outputs. + ReductionOutputMap result_ir_arrays; + + // Skip all parameter buffers first. + int ir_arrays_idx = 0; + auto outputs_span = absl::MakeSpan(outputs); + for (HloInstruction* root : analysis_.fusion_roots()) { + int num_results = + root->shape().IsTuple() ? root->shape().tuple_shapes_size() : 1; + result_ir_arrays[root] = outputs_span.subspan(ir_arrays_idx, num_results); + ir_arrays_idx += num_results; + } + + KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); + + // Use raw block_id_y to select the i-th parallel reduction to run. 
Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + const std::vector>& instr_index_groups = + reduction_codegen_info->GetIndexGroups(); + Shape reduce_operand_shape = + reduction_codegen_info->GetReduceOperandShape(); + + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), + llvm::cast(raw_block_id_y)); + for (int i = 0; i < instr_index_groups.size(); ++i) { + TF_RETURN_IF_ERROR(ksl.IfWithStatus( + absl::StrCat("reduce-group-", i), + builder->CreateICmpEQ(raw_block_id_y, builder->getInt32(i)), [&] { + return EmitIRForReduction(builder, ir_emitter_context_, fusion_op(), + instr_index_groups[i], fused_emitter, + result_ir_arrays, *reduction_codegen_info, + reduce_operand_shape); + })); + } + + return OkStatus(); + }; + + TF_ASSIGN_OR_RETURN( + result.thunks.emplace_back(), + BuildKernelThunkForFusion(ir_emitter_context_, kernel_cache, fusion_op(), + fusion_, launch_dimensions, "", kernel_builder, + builder)); + return result; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.h b/tensorflow/compiler/xla/service/gpu/fusions/reduction.h new file mode 100644 index 00000000000000..5bb45d4e3c01ea --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/reduction.h @@ -0,0 +1,120 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ + +#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" + +namespace xla { +namespace gpu { + +// Generates code for reduction to contiguous dimensions. 
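+//
+// A "row" reduction reduces the minor-most, contiguous dimension(s) of the
+// input. A "column" reduction keeps the minor-most dimension and reduces a
+// more major one, which is why its implementation transposes tiles through
+// shared memory before the final warp-level reduction.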
+// +// Row reduction uses the following algorithm described in CUDA-like +// pseudocode: +// +// ``` +// __global__ void reduce(int num_rows, float *in, float out) { +// __shared__ float[32] cache; +// int offset = blockDim.x * blockIdx.x + threadIdx.x; +// if (offset >= num_rows) return; +// int tile_bound = std::min(offset + kTileSizeX, num_rows); +// float accum = 0; +// for (int i=offset; i Emit( + KernelReuseCache& kernel_cache, + llvm::IRBuilder<>* builder) const override; + + private: + mlir::lmhlo::FusionOp fusion_op() const { return fusion_op_; } + + IrEmitterContext& ir_emitter_context_; + ElementalIrEmitter& elemental_emitter_; + mlir::lmhlo::FusionOp fusion_op_; + const HloFusionInstruction& fusion_; + HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc new file mode 100644 index 00000000000000..e38504096376a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc @@ -0,0 +1,120 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" + +#include +#include + +#include "absl/types/span.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" + +namespace xla { +namespace gpu { +namespace { + +// TODO(b/291536641): Clean this up. What's the difference between this and the +// caller? +std::optional> BuildConstantInitializerThunk( + mlir::Operation* op, absl::Span init_value, mlir::Value dest, + const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { + int64_t num_bytes = init_value.size(); + if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { + return {{std::make_unique(Thunk::ThunkInfo(op), dest_slice, + dest)}}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. 
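+  // For example, a one-byte init value 0xAB is widened to 0xABAB and then to
+  // 0xABABABAB, which a single 32-bit memset thunk can splat across the whole
+  // destination buffer.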
+ if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { + uint16_t pattern16; + if (num_bytes == 1) { + uint8_t b = init_value.front(); + pattern16 = uint16_t{b} | (uint16_t{b} << 8); + } else { + memcpy(&pattern16, init_value.data(), sizeof(pattern16)); + } + uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); + return {{std::make_unique( + Thunk::ThunkInfo(op), pattern32, dest_slice, dest)}}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == + 0) { + uint32_t word; + memcpy(&word, init_value.data(), sizeof(word)); + return {{std::make_unique(Thunk::ThunkInfo(op), word, + dest_slice, dest)}}; + } + + return std::nullopt; +} + +} // namespace + +StatusOr>> BuildConstantInitializerThunk( + IrEmitterContext& ir_emitter_context, mlir::Operation* op, + mlir::Value init_value, mlir::Value dest) { + mlir::DenseElementsAttr const_init; + if (auto get_global_memref = + mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + auto global_memref = + mlir::SymbolTable::lookupNearestSymbolFrom( + get_global_memref, get_global_memref.getNameAttr()); + if (global_memref.getConstant() && global_memref.getInitialValue()) { + // If the initial value happens to be a constant, generate a specialized + // thunk. + const_init = global_memref.getInitialValue() + .value() + .cast(); + } + } else if (auto constant = mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + const_init = constant.getValue().dyn_cast(); + } + + if (const_init) { + std::vector literal_bytes; + TF_RETURN_IF_ERROR( + CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); + + TF_ASSIGN_OR_RETURN( + auto dest_slice, + GetAllocationSlice(dest, ir_emitter_context.allocations())); + + const Shape dest_shape = GetShape(dest); + return BuildConstantInitializerThunk(op, literal_bytes, dest, dest_slice, + dest_shape); + } + return std::nullopt; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h new file mode 100644 index 00000000000000..b52c535f034654 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace gpu { + +// Attempts to build an initializer constant for the given value. Returns an +// empty optional if the value is not a constant. +StatusOr>> BuildConstantInitializerThunk( + IrEmitterContext& ir_emitter_context, mlir::Operation* op, + mlir::Value init_value, mlir::Value dest); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc index c0c2640a895937..8d0ca80b2d9182 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc @@ -360,5 +360,38 @@ llvm::Value* TilingThreadIdInfo::GEPIntoSharedMemory( return b->CreateAddrSpaceCast(gep, pointer_in_addressspace); } +llvm_ir::IrArray::Index GetUnnormalizedIndex( + const llvm_ir::IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, + absl::Span dims_in_elems) { + CHECK_EQ(normalized_shape_index.size(), 3); + // If the normalization only add a new dimensions of size 1, + // generate simpler indexing. LLVM doesn't always simplify the more + // complicated indexing and this prevents it from vectorizing some + // cases. We do this only for major_to_minor memory layout. 
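+  // E.g. a normalized index into [1, H, W] becomes the {h, w} index into the
+  // rank-2 unnormalized shape directly, instead of going through the bitcast
+  // of the linearized index used as the fallback below.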
+ if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && + unnormalized_shape.layout().minor_to_major(1) == 0) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return llvm_ir::IrArray::Index({multidim[1], multidim[2]}, + unnormalized_shape, + normalized_shape_index.GetType()); + } + if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && + unnormalized_shape.layout().minor_to_major(1) == 1) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return llvm_ir::IrArray::Index({multidim[2], multidim[1]}, + unnormalized_shape, + normalized_shape_index.GetType()); + } + return normalized_shape_index.SourceIndexOfBitcast( + ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, builder); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h index 1d8dd40e916300..be697addeb78e0 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h @@ -138,6 +138,11 @@ StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, llvm::Type* index_ty, const TileElementGenerator& tile_element_generator); +llvm_ir::IrArray::Index GetUnnormalizedIndex( + const llvm_ir::IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, + absl::Span dims_in_elems); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 3070cbf9b1dfa7..95d0a67c6c4f27 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,13 +21,17 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d51ef9046c6e5b..3e23e291c96330 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -92,6 +92,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fused_mha_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fusions/fusions.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" #include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" @@ -2046,8 +2047,6 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } - case HloFusionAnalysis::EmitterFusionKind::kReduction: - return EmitUnnestedReduction(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kTranspose: return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: @@ -2055,6 +2054,7 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: + case HloFusionAnalysis::EmitterFusionKind::kReduction: return FailedPrecondition( "Loop fusion should have been handled by GetFusionEmitter."); } @@ -3207,86 +3207,6 @@ IrEmitterUnnested::BuildKernelThunkForNonFusionOp( launch_dimensions); } -std::unique_ptr IrEmitterUnnested::BuildConstantInitializerThunk( - mlir::Operation* op, absl::Span init_value, mlir::Value dest, - const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { - int64_t num_bytes = init_value.size(); - if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { - return std::make_unique(Thunk::ThunkInfo(op), dest_slice, - dest); - } - - // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the literal 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. - if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { - uint16_t pattern16; - if (num_bytes == 1) { - uint8_t b = init_value.front(); - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, init_value.data(), sizeof(pattern16)); - } - uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); - return std::make_unique(Thunk::ThunkInfo(op), - pattern32, dest_slice, dest); - } - - // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit - // memset so long as all 32-bit words of the scalar are equal to each other. - if (num_bytes >= 4 && num_bytes % 4 == 0 && - memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == - 0) { - uint32_t word; - memcpy(&word, init_value.data(), sizeof(word)); - return std::make_unique(Thunk::ThunkInfo(op), word, - dest_slice, dest); - } - - return nullptr; -} - -StatusOr> -IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Operation* op, - mlir::Value init_value, - mlir::Value dest) { - mlir::DenseElementsAttr const_init; - if (auto get_global_memref = - mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - auto global_memref = - mlir::SymbolTable::lookupNearestSymbolFrom( - get_global_memref, get_global_memref.getNameAttr()); - if (global_memref.getConstant() && global_memref.getInitialValue()) { - // If the initial value happens to be a constant, generate a specialized - // thunk. 
- const_init = global_memref.getInitialValue() - .value() - .cast(); - } - } else if (auto constant = mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - const_init = constant.getValue().dyn_cast(); - } - - if (const_init) { - std::vector literal_bytes; - TF_RETURN_IF_ERROR( - CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); - - TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest)); - - const Shape dest_shape = GetShape(dest); - auto thunk = BuildConstantInitializerThunk(op, literal_bytes, dest, - dest_slice, dest_shape); - if (thunk) { - return {std::move(thunk)}; - } - } - return std::unique_ptr(); -} - Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest) { @@ -3294,10 +3214,11 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, auto init_type = init_value.getType().dyn_cast(); TF_RET_CHECK(init_type.getRank() == 0); - TF_ASSIGN_OR_RETURN(std::unique_ptr constant_init_thunk, - TryBuildConstantInitializerThunk(op, init_value, dest)); + TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, + BuildConstantInitializerThunk(*ir_emitter_context_, op, + init_value, dest)); if (constant_init_thunk) { - AddThunkToThunkSequence(std::move(constant_init_thunk)); + AddThunkToThunkSequence(*std::move(constant_init_thunk)); return OkStatus(); } @@ -3329,77 +3250,6 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, return OkStatus(); } -Status IrEmitterUnnested::BuildFusedInitializerThunk( - mlir::lmhlo::FusionOp fusion, int output_index) { - auto reduce = mlir::dyn_cast_or_null( - fusion.getFusionRoots()[output_index]); - - TF_RET_CHECK(reduce); - TF_RET_CHECK(reduce.getNumResults() == 1); - - mlir::Value init_value = reduce.getInitValues()[0]; - mlir::Value dest = fusion.getOutputBuffers()[output_index]; - TF_ASSIGN_OR_RETURN( - std::unique_ptr constant_init_thunk, - TryBuildConstantInitializerThunk(fusion, init_value, dest)); - if (constant_init_thunk) { - AddThunkToThunkSequence(std::move(constant_init_thunk)); - return OkStatus(); - } - - auto input_buffers = fusion.getInputBuffers(); - - const Shape dest_shape = GetShape(dest); - bool use_experimental_block_size = - ir_emitter_context_->debug_options() - .xla_gpu_enable_experimental_block_size(); - - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - dest_shape, ir_emitter_context_->gpu_device_info(), - use_experimental_block_size)); - - TF_ASSIGN_OR_RETURN( - std::optional> opt_ir_arrays, - BuildKernelThunkForFusion( - fusion, launch_dimensions, - /*discriminator=*/absl::StrCat("init_", output_index))); - if (!opt_ir_arrays.has_value()) { - // The kernel was reused, no need to emit code. 
- return OkStatus(); - } - std::vector& ir_arrays = opt_ir_arrays.value(); - - const llvm_ir::IrArray dest_array = - ir_arrays[input_buffers.size() + output_index]; - - const HloComputation* fused_computation = - *GetOrCreateSubComputationFromRegion(&fusion.getRegion(), - /*is_fusion=*/true); - - FusedIrEmitter fused_emitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { - return ir_arrays[i].EmitReadArrayElement(index, &b_); - }); - } - HloInstruction* instr = fused_computation->root_instruction(); - if (instr->opcode() != HloOpcode::kTuple) { - CHECK_EQ(0, output_index); - } else { - instr = instr->mutable_operand(output_index); - } - TF_RET_CHECK(instr->shape().IsArray()); - TF_ASSIGN_OR_RETURN(auto generator, - fused_emitter.GetGenerator(*instr->operand(1))); - TF_RETURN_IF_ERROR( - ParallelLoopEmitter(generator, {dest_array}, launch_dimensions, &b_) - .EmitLoop(GetIrNameFromLoc(fusion.getLoc()))); - return OkStatus(); -} - StatusOr> IrEmitterUnnested::BuildWhileThunk( mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) { // Generate thunk sequence for while 'condition'. @@ -3443,263 +3293,6 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } -// Gets the output offset as calculated from thread_id.x (to be applied to the -// offset calculated from block_id and thread_id.y). -static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, - llvm::Value* thread_id_x, - llvm::Type* index_ty, - llvm::IRBuilder<>* b) { - int64_t multiplier = tiling_scheme.GetIndexingOrder() == kStridedIndexingX - ? tiling_scheme.GetVectorSize() - : tiling_scheme.GetTileSizeFor(kDimX); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, multiplier)); -} - -static IrArray::Index GetUnnormalizedIndex( - const IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, - absl::Span dims_in_elems) { - CHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. 
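  // Example: an f32[10,20]{1,0} output is normalized to the 3-component index
  // space [1,10,20]; the first branch below detects that dims()[1]/dims()[2]
  // line up with the unnormalized dimensions and rebuilds the 2D index
  // directly from multidim()[1] and multidim()[2], avoiding the generic
  // bitcast path at the bottom of the function.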
- if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, - normalized_shape_index.GetType()); - } - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && - unnormalized_shape.layout().minor_to_major(1) == 1) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape, - normalized_shape_index.GetType()); - } - return normalized_shape_index.SourceIndexOfBitcast( - ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, b_); -} - -static int GetNumOutputs(const Shape& shape) { - if (shape.IsTuple()) { - return shape.tuple_shapes_size(); - } - return 1; -} - -ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState( - mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info, - absl::Span reduce_instr_index_group, - FusedIrEmitter& fused_emitter) { - ReductionCodegenState reduction_codegen_state(reduction_info); - VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); - - for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - int num_partial_results = reduction_codegen_state.GetNumPartialResults(); - for (int op_result_idx = 0; - op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { - Shape result_shape = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes(op_result_idx) - : reduce_hlo->shape(); - - llvm::Type* element_type = - llvm_ir::PrimitiveTypeToIrType(result_shape.element_type(), module_); - llvm::AllocaInst* reduction_input_address = - llvm_ir::EmitAllocaAtFunctionEntry(element_type, - "reduction_input_address", &b_); - - llvm::AllocaInst* partial_result_address = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - element_type, /*element_count=*/b_.getInt32(num_partial_results), - "partial_reduction_result", &b_); - - const HloInstruction* init_value = - reduce_hlo->init_values()[op_result_idx]; - - // Initialize the partial result with the initial value of the reduction. - llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(*init_value))( - IrArray::Index(b_.getInt32Ty())) - .value(); - - for (int i = 0; i < num_partial_results; ++i) { - b_.CreateStore(init_ir_value, - InBoundsGEP(partial_result_address->getAllocatedType(), - partial_result_address, {b_.getInt32(i)})); - } - - const TilingScheme& tiling_scheme = - reduction_codegen_state.GetTilingScheme(); - int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX); - llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { - if (reduction_codegen_state.IsRowReduction()) { - // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > - 1) { - return nullptr; - } - // Allocate __shared__ - // cache[num_partial_results][num_warps][scaling_factor]. 
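          // For example, with 256 threads per block and a warp size of 32,
          // num_warps below is 8, so each partial result gets one shared
          // accumulator slot per warp; AllocateShared additionally wraps the
          // array in the tiling scheme's thread-id scaling factor as the
          // major-most dimension.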
- CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); - int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(tiling_scheme, element_type, - {num_partial_results, num_warps}, - "shared_cache"); - } else { - // Allocate __shared__ - // cache[num_threads][num_threads + 1], where - // num_threads == num_threads_x == num_threads_y. The "+1" is used to - // avoid bank conflicts. - // - // (Although each thread produces num_partial_results results, we - // don't need that much cache: Only one result is live at a time.) - CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY)); - return AllocateShared(tiling_scheme, element_type, - {num_threads_x, num_threads_x + 1}, - "shared_cache"); - } - }(); - - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - reduction_codegen_state.SetCalculationStateFor( - {shared_cache, init_ir_value, partial_result_address, - reduction_input_address, input_gen}, - reduce_hlo, op_result_idx); - } - } - - return reduction_codegen_state; -} - -void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp) { - // This only works when the block size is a multiple of 32 threads. - - // We check this here as a mistake in the number of threads per - // block is very hard to detect. - CHECK_EQ(threads_per_block % 32, 0); - CHECK_EQ(WarpSize() % num_results_per_warp, 0); - - for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { - absl::InlinedVector reduction_params; - - for (auto acc : partial_result_addresses) { - reduction_params.push_back(acc.first); - } - - for (auto [partial_result_address, element_type] : - partial_result_addresses) { - int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", &b_); - - reduction_params.push_back(result_from_other_lane); - - // Bitcast cannot be applied to aggregate types (even packed ones), so - // we bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffled_value_type = - element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; - auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { - return b_.CreatePointerBitCastOrAddrSpaceCast( - ptr, shuffled_value_type->getPointerTo()); - }; - - llvm::Value* partial_result = - b_.CreateLoad(shuffled_value_type, - convert_pointer_for_shuffle(partial_result_address), - "partial_reduction_result"); - b_.CreateStore( - EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), - convert_pointer_for_shuffle(result_from_other_lane)); - } - - StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, - *reducer, reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - b_.CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); - } - } -} - -llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( - int partial_result_idx, llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const IrEmitterUnnested::ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx) { - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; - - IrArray::Index start_offset = [&] { - llvm::Value* x_loc = thread_id_info.thread_id_x; - llvm::Value* y_loc = thread_id_info.thread_id_y; - if (!reduction_codegen_state.IsRowReduction()) { - std::swap(x_loc, y_loc); - } - llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_); - return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_) - .AddOffsetToDim(start_offset_x, kDimX, &b_); - }(); - - const IrArray& output_array = output_arrays.at(reduction)[output_idx]; - const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); - Shape reduction_kept_element_shape = - ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); - - // Given the IrArray index of a reduction input, returns the linear address of - // the reduction output as if the reduction were going to keep the input shape - // with the dimensions being reduced moved. - llvm::Value* untransposed_output_linear_address = [&] { - const llvm_ir::IrArray::Index index = - start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, &b_); - if (reduction_codegen_state.IsRowReduction()) { - // For row-reduction, y-coordinate determines which row we write into. - return index[kDimY]; - } - // For column reduction, we get the transposed address. - absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); - llvm::Value* x_dim_size = - index.GetConstantWithIndexType(dims_in_elem[kDimX]); - llvm::Value* x_block_offset = b_.CreateMul(index[kDimZ], x_dim_size); - return b_.CreateAdd(x_block_offset, index[kDimX]); - }(); - - // A reduction is allowed to transpose its output. For example, suppose - // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are - // allowed to produce as output either f32[10,30]{1,0} (no transpose) or - // f32[10,30]{0,1} (transposing the two output dims). - // - // At this point in the function we have a "partial sum" of input elements - // (stored in partial_result_addresses), and we need to accumulate it into - // the correct output element. 
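  // Concretely, for the f32[10,20,30] example above the kept shape is
  // f32[10,30]: element_index below decomposes the linear offset into {i, k}
  // coordinates of that kept shape, and output_index re-expresses the same
  // coordinates against the actual output array's shape and layout, so the
  // address returned at the end is correct even if the two output dimensions
  // are stored transposed.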
- IrArray::Index element_index( - /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, &b_); - IrArray::Index output_index(element_index.multidim(), output_array.GetShape(), - element_index.GetType()); - - return output_array.EmitArrayElementAddress(output_index, &b_, - "output_element_address"); -} - llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type, llvm::Twine name) { @@ -3709,225 +3302,6 @@ llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, name); } -void IrEmitterUnnested::WriteReductionOutput( - llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx, - const absl::Span values) { - const HloComputation* reducer = reduction->to_apply(); - for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { - auto [output_ptr, type] = typed_ptr; - llvm::Value* output_address = GetOutputAddressForReduction( - partial_result_idx, index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, oidx); - if (reduction_codegen_state.IsRaceFree()) { - b_.CreateStore(b_.CreateLoad(type, output_ptr, "output"), output_address); - } else { - CHECK_EQ(values.size(), 1); - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - &b_, *ir_emitter_context_, *reducer, output_address, output_ptr, - type)); - } - } -} - -void IrEmitterUnnested::EmitReductionOutputForRowReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return b_.CreateICmpEQ(value, constant(0)); - }; - - int num_outputs = reducer->num_parameters() / 2; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - absl::InlinedVector current_outputs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - current_outputs.push_back( - {InBoundsGEP(state.partial_result_address->getAllocatedType(), - state.partial_result_address, - {constant(partial_result_idx)}, "current_output"), - state.partial_result_address->getAllocatedType()}); - } - - int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; - int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); - - KernelSupportLibrary ksl(&b_); - llvm::Value* warp_id = - b_.CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); - - auto emit_write_output = [&](llvm::Value* write_condition, - const absl::Span values) { - ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, - partial_result_idx, values); - }); - }; - - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(b_.CreateAnd( - 
thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } - - ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - &b_, state.shared_cache, {constant(partial_result_idx), warp_id}); - Store(Load(current_outputs[oidx].second, current_outputs[oidx].first), - shmem_output_addr); - } - }); - - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? - EmitSyncThreads(); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - &b_, state.shared_cache, - {constant(partial_result_idx), thread_id_info.lane_id}); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - /* Insure initial value address is in generic, not scratch. */ - llvm::Value* initial_value_addr = - CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", &b_), - element_type); - b_.CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = b_.CreateICmpULT( - thread_id_info.thread_id_x, - constant(tiling_scheme.GetNumThreadsFor(kDimX) / WarpSize())); - - llvm::Value* selected_value = - b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr); - - selected_values.push_back({selected_value, element_type}); - } - - // If only one warp is present in the block, then we don't need inter-warp - // reduction. - // TODO(b/241414088) If only warp is present, then inter-warp communication - // using shared memory and synchronization using barrier is also unnecessary - // and should be removed. - if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(selected_values), - tiling_scheme.GetNumThreadsPerBlock()); - } - - emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); - }); -} - -void IrEmitterUnnested::EmitReductionOutputForColumnReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - KernelSupportLibrary ksl(&b_); - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return b_.CreateICmpEQ(value, constant(0)); - }; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - int num_outputs = reducer->num_parameters() / 2; - - // Wait for reads from shmem in the last iteration to complete. (If this is - // slow, we could "double-buffer" by having two shmem buffers and switching - // between them.) - if (partial_result_idx > 0) { - EmitSyncThreads(); - } - - // Store the transpose in shared memory. 
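  // In CUDA-like pseudocode, the shared-memory round trip below is
  //   cache[threadIdx.x][threadIdx.y] = partial_result;
  //   __syncthreads();
  //   transposed = cache[threadIdx.y][threadIdx.x];
  // so that partial results belonging to the same output column end up in one
  // warp and can be combined by the warp-shuffle reduction further down.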
- for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::GlobalVariable* shared_cache = state.shared_cache; - llvm::AddrSpaceCastInst* shmem_output_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - &b_, shared_cache, - {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, - "shmem_output_address")); - llvm::Value* current_output = - InBoundsGEP(state.partial_result_address->getAllocatedType(), - state.partial_result_address, - {constant(partial_result_idx)}, "current_output"); - - llvm::Value* current_output_value = - Load(state.partial_result_address->getAllocatedType(), current_output); - b_.CreateStore(current_output_value, shmem_output_addr); - } - - EmitSyncThreads(); - - // Get transposed element from shared memory. - absl::InlinedVector shmem_transposed_addrs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::AddrSpaceCastInst* shmem_transposed_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - &b_, state.shared_cache, - {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, - "shmem_transposed_addr")); - shmem_transposed_addrs.push_back( - {shmem_transposed_addr, llvm::cast( - shmem_transposed_addr->getPointerOperand()) - ->getResultElementType()}); - } - - EmitFullWarpShuffleDownLoopForReduce(reducer, - absl::MakeSpan(shmem_transposed_addrs), - tiling_scheme.GetNumThreadsPerBlock()); - - // Some warps in the block are completely outside of the bound of the - // tensor, so they should not write any output at all. - llvm::Value* has_output = - b_.CreateAnd(b_.CreateICmpULT(GetStartOffsetX(tiling_scheme, - thread_id_info.thread_id_y, - index_ty, &b_), - tiling_kernel_info.output_tile_bounds[1]), - b_.CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); - - ksl.If("reduction_write_output", - b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, - partial_result_idx, shmem_transposed_addrs); - }); -} - llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering::SequentiallyConsistent, "workgroup"); @@ -4107,239 +3481,6 @@ llvm::GlobalVariable* IrEmitterUnnested::AllocateShared( array_type, buffer_name); } -// Generate a single element of the tile (update the accumulator state) for a -// given reducer of index `i`. 
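// In scalar terms, for a plain f32 add-reduction this emits the equivalent of
//   accum[partial_result_index] =
//       reducer(accum[partial_result_index], input[input_index]);
// with one accumulator/input pair per result of a variadic reduction.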
-void IrEmitterUnnested::GenerateElementForReducer( - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, - const llvm_ir::IrArray::Index& index_without_linear, - const IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays) { - HloComputation* reducer = reduction->to_apply(); - CHECK_EQ(reducer->num_parameters() % 2, 0); - - absl::InlinedVector reduction_accumulators; - absl::InlinedVector reduction_input_value; - for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - codegen_state.GetCalculationStateFor(reduction, red_idx); - - llvm::AllocaInst* input_address = state.input_address; - llvm::AllocaInst* partial_reduction_result_address = - state.partial_result_address; - llvm::Value* const input_ir_value = *state.input_gen( - num_partial_results > 1 ? index_without_linear : input_index); - b_.CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = - InBoundsGEP(partial_reduction_result_address->getAllocatedType(), - partial_reduction_result_address, {partial_result_index}); - reduction_accumulators.push_back(partial_result_address); - reduction_input_value.push_back(input_address); - } - - absl::InlinedVector reduction_params; - for (llvm::Value* acc : reduction_accumulators) { - reduction_params.push_back(acc); - } - for (llvm::Value* value : reduction_input_value) { - reduction_params.push_back(value); - } - - // Emit a call to the variadic reducer. Since it may be returning a - // tuple, we can't return it directly as a value. Instead, before - // the call, we create N (N = # arguments in the tuple) allocas, one - // for each returned argument, then when we make the call we pass N - // pointers as last parameters, the called computation writes into - // those pointers, and we have returned values on the stack (as well - // as pointers to them). 
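  // For a variadic reducer such as a fused (sum, max) with signature
  // reducer(acc0, acc1, x0, x1) -> (acc0', acc1'), reduction_params holds
  // {acc0_ptr, acc1_ptr, x0_ptr, x1_ptr}; the call below adds the return-value
  // allocas described above, and the stores after the call copy the returned
  // scalars back into the accumulators.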
- StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, *reducer, - reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]); - } -} - -Status IrEmitterUnnested::EmitIRForReduction( - mlir::lmhlo::FusionOp fusion, - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenInfo& reduction_info, const Shape& input_shape) { - std::vector reductions; - ExtraOutputGensMap extra_output_gens; - - for (const HloInstruction* hlo : instr_index_group) { - if (IsReductionFromOrToContiguousDimensions(*hlo)) { - reductions.push_back(Cast(hlo)); - } else { - extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); - } - } - - CHECK(!reductions.empty()) << " expect at least one reduce instructions."; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); - llvm::Type* index_ty = - GetIndexTypeForKernel(fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumberOfBlocksPhysical(), - &b_); - ReductionCodegenState codegen_state = GenerateReductionCodegenState( - fusion, reduction_info, reductions, fused_emitter); - - EmitTileElementFunction emit_reduction_element = - [&](const TilingThreadIdInfo& thread_id_info, const IrArray::Index& index, - llvm::Value* y_loc, llvm::Value* x_loc) { - IrArray::Index input_index = GetUnnormalizedIndex( - index, input_shape, &b_, - codegen_state.GetTilingScheme().GetDimsInElems()); - llvm::Value* partial_result_index = - codegen_state.IsRowReduction() - ? b_.getInt32(0) - : b_.CreateSub( - x_loc, - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, - index_ty, &b_)); - - // Clear the linear index field of the IrArray::Index to enable the use - // of GetElementPointer with array types. This enables the vectorization - // of the computation for different partial results. Use this index if - // 'num_partial_results > 1'. - int num_partial_results = codegen_state.GetNumPartialResults(); - IrArray::Index index_without_linear{input_index.multidim(), input_shape, - input_index.GetType()}; - - // Emit code to generate the input and perform the reduction computation - // for each reduction instruction. - for (const HloReduceInstruction* reduce : reductions) { - GenerateElementForReducer(reduce, partial_result_index, codegen_state, - index_without_linear, input_index, - num_partial_results, result_ir_arrays); - } - - // Emit code to generate the output for the non-reduction instructions - // in the fusion, if any. 
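          // "Extra outputs" are fusion roots that are not reductions; e.g. in
          // a multi-output fusion returning (reduce(x), exp(x)), the exp(x)
          // element for the current input_index is written out here alongside
          // the running reduction.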
- TF_CHECK_OK(EmitExtraOutputsForReduce(input_shape, result_ir_arrays, - input_index, reduction_info, - extra_output_gens)); - }; - - TF_ASSIGN_OR_RETURN( - TilingKernelInfo tiling_kernel_info, - EmitTilingKernel( - &b_, tiling_scheme, index_ty, - [&](const TilingThreadIdInfo& thread_id_info, - const IrArray::Index& index, ValueVector2 tile_dimensions) { - EmitTile(&b_, codegen_state.GetTilingScheme(), index, - thread_id_info, tile_dimensions, emit_reduction_element); - })); - - KernelSupportLibrary ksl(&b_); - for (const HloReduceInstruction* reduce : reductions) { - for (int partial_result_idx = 0; - partial_result_idx < reduction_info.GetNumPartialResults(); - ++partial_result_idx) { - if (codegen_state.IsRowReduction()) { - EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, - partial_result_idx); - } else { - EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, - reduce, partial_result_idx); - } - } - } - - return OkStatus(); -} - -Status IrEmitterUnnested::EmitUnnestedReduction( - mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - auto* reduction_codegen_info = fusion_analysis.GetReductionCodegenInfo(); - // Set flag to false as Reduction has it's own custom logic of choosing a - // block size. - TF_ASSIGN_OR_RETURN(auto launch_dimensions, - fusion_analysis.GetLaunchDimensions( - /*use_experimental_block_size=*/false)); - - VLOG(3) << "Launch dimensions of " - << mlir::mhlo::GetDebugNameFromLocation(fusion.getLoc()) << ": " - << launch_dimensions.ToString(); - if (!reduction_codegen_info->IsRaceFree()) { - absl::Span fusion_roots = - fusion_analysis.fusion_roots(); - for (int i = 0; i < fusion_roots.size(); ++i) { - if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { - TF_RETURN_IF_ERROR(BuildFusedInitializerThunk(fusion, i)); - } - } - } - - TF_ASSIGN_OR_RETURN( - std::optional> opt_ir_arrays, - BuildKernelThunkForFusion(fusion, launch_dimensions)); - if (!opt_ir_arrays.has_value()) { - // The kernel was reused, no need to emit code. - return OkStatus(); - } - std::vector& ir_arrays = opt_ir_arrays.value(); - - FusedIrEmitter fused_emitter(elemental_emitter_); - const HloComputation* fused_computation = fusion_analysis.fused_computation(); - CHECK_LT(fused_computation->num_parameters(), ir_arrays.size()); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - llvm_ir::IrArray ir_array = ir_arrays[i]; - HloInstruction* fused_operand = fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, - [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement(index, &b_, - fused_operand->name()); - }); - } - - // Get outputs. - ReductionOutputMap result_ir_arrays; - - // Skip all parameter buffers first. - int ir_arrays_idx = fused_computation->num_parameters(); - for (HloInstruction* root : fusion_analysis.fusion_roots()) { - int get_num_results = GetNumOutputs(root->shape()); - result_ir_arrays[root] = - absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, get_num_results); - ir_arrays_idx += get_num_results; - } - - KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - - // Use raw block_id_y to select the i-th parallel reduction to run. Using - // block_id_y instead of block_id_x simplifies the index calculation - // for reduction code generation as the block_id_y is orthogonal to - // the indices used within the reductions. 
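  // The emitted kernel therefore has the shape
  //   if (blockIdx.y == 0) { ...IR for reduce-group-0... }
  //   if (blockIdx.y == 1) { ...IR for reduce-group-1... }
  //   ...
  // with the y-dimension of the grid covering one block index per independent
  // reduction group, as also encoded in the range metadata below.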
- const std::vector>& instr_index_groups = - reduction_codegen_info->GetIndexGroups(); - Shape reduce_operand_shape = reduction_codegen_info->GetReduceOperandShape(); - - llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); - llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(raw_block_id_y)); - for (int i = 0; i < instr_index_groups.size(); ++i) { - TF_RETURN_IF_ERROR(ksl.IfWithStatus( - StrCat("reduce-group-", i), - b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] { - return EmitIRForReduction( - fusion, instr_index_groups[i], fused_emitter, result_ir_arrays, - *reduction_codegen_info, reduce_operand_shape); - })); - } - - return OkStatus(); -} - // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index a55276d9dcdb4f..905ed163d64ead 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -324,75 +324,6 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenInfo& reduction_info, const ExtraOutputGensMap& extra_output_gens); - // Generates code for reduction to contiguous dimensions. - // - // Row reduction uses the following algorithm described in CUDA-like - // pseudocode: - // - // ``` - // __global__ void reduce(int num_rows, float *in, float out) { - // __shared__ float[32] cache; - // int offset = blockDim.x * blockIdx.x + threadIdx.x; - // if (offset >= num_rows) return; - // int tile_bound = std::min(offset + kTileSizeX, num_rows); - // float accum = 0; - // for (int i=offset; i reduce_instr_index_group, - FusedIrEmitter& fused_emitter); - - // Wraps up the code generation for a tile block of a reduction kernel: - // write the calculated output into the output tensor. - void EmitReductionOutput( - llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion, - absl::Span reduce_instr_index_group, - const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info); - - // Returns the address to write the reduction output to. - llvm::Value* GetOutputAddressForReduction( - int partial_result_idx, llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx); - - // Performs the actual write of the reduction result. - using TypedPointer = std::pair; - void WriteReductionOutput( - llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx, - const absl::Span values); - - // `current_output`: the value the tile has calculated. - // `output_address`: address where the output value has to be written. - void EmitReductionOutputForRowReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx); - - // Same arguments as EmitReductionOutputForRowReduction. 
- void EmitReductionOutputForColumnReduction( - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, - llvm::Type* index_ty, const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx); - - // Emits code for reductions in the output_instructions. - Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion, - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, - const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenInfo& reduction_info, - const Shape& input_shape); - - // Generate a single element of the tile (update the accumulator state) for a - // given reducer of index `i`. - void GenerateElementForReducer( - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays); - - // Emits shuffle-down reduction for the `partial_result_address` using the - // reduction computation `reducer`, writes output into - // `partial_result_address`. - // - // Multiple partial_result_address inputs happen when doing variadic - // reduction: each one should get the output value. - void EmitFullWarpShuffleDownLoopForReduce( - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp = 1); - // Allocates a shared tile of given dimensions, applying scaling specified in // tilng_scheme as a major-most dimension to avoid collisions. llvm::GlobalVariable* AllocateShared( @@ -618,20 +472,8 @@ class IrEmitterUnnested : public IrEmitter { mlir::Operation* op, mlir::ValueRange needed_operands, const LaunchDimensions& launch_dimensions); - // Returns a thunk that, given a reduce or select-and-scatter op, - // initializes its memory to the appropriate initial value. - std::unique_ptr BuildConstantInitializerThunk( - mlir::Operation* op, absl::Span init_value, - mlir::Value dest, const BufferAllocation::Slice& dest_slice, - const Shape& output_shape); - - StatusOr> TryBuildConstantInitializerThunk( - mlir::Operation* op, mlir::Value init_value, mlir::Value dest); - Status BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest); - Status BuildFusedInitializerThunk(mlir::lmhlo::FusionOp fusion, - int output_index); // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. From 3b205a38b4cd3dc7fede52f6481b8400eedbb56c Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Tue, 25 Jul 2023 02:24:41 -0700 Subject: [PATCH 099/410] [XLA:GPU] Modifying -NaN scatter test to also run the operand with -NaN. 
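As background (a minimal standalone C++ sketch, not taken from the XLA sources): every ordered comparison involving a NaN evaluates to false, so a comparison-based maximum is sensitive to which operand carries the NaN, while an fmax-style maximum drops a lone NaN. Covering a NaN in the scatter operand as well as in the updates therefore exercises both argument positions of the max_f32 combiner.

#include <cmath>
#include <limits>

// Comparison-based max: returns b whenever a is NaN, and NaN whenever b is.
inline float cmp_max(float a, float b) { return a > b ? a : b; }

inline void demo() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  (void)cmp_max(nan, 2.5f);    // 2.5f  ("nan > 2.5f" is false, picks b)
  (void)cmp_max(2.5f, nan);    // NaN   ("2.5f > nan" is false, picks b)
  (void)std::fmax(nan, 2.5f);  // 2.5f  (fmax ignores a single NaN)
}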
PiperOrigin-RevId: 550820365 --- tensorflow/compiler/xla/tests/scatter_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 12f09e11b67fd3..4af119cbb1a9bc 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -903,7 +903,7 @@ max_f32 (lhs: f32[], rhs: f32[]) -> f32[] { ENTRY main { indices = s32[2] parameter(0) - constant_with_nans = f32[3] constant({nan, 2.5, nan}) + constant_with_nans = f32[3] constant({-nan, 2.5, nan}) operand = f32[3,3] broadcast(constant_with_nans), dimensions={1} updates = f32[2,3] constant({{4.6, -nan, 1}, {2.3, 3.1, 1.6}}) scatter = f32[3,3] scatter(operand, indices, updates), From ad42fc8d17a7e71e4492a87c82410c94b1cbd8e2 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jul 2023 03:31:47 -0700 Subject: [PATCH 100/410] Extract transpose fusions from ir_emitter_unnested. PiperOrigin-RevId: 550833735 --- .../compiler/xla/service/gpu/fusions/BUILD | 23 ++ .../xla/service/gpu/fusions/fusions.cc | 4 + .../xla/service/gpu/fusions/transpose.cc | 232 ++++++++++++++++++ .../xla/service/gpu/fusions/transpose.h | 81 ++++++ .../xla/service/gpu/ir_emitter_unnested.cc | 222 +---------------- .../xla/service/gpu/ir_emitter_unnested.h | 46 ---- 6 files changed, 341 insertions(+), 267 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/transpose.cc create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/transpose.h diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 644544246200ce..5bd8adde24b674 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -62,6 +62,7 @@ cc_library( ":in_place_dynamic_update_slice", ":loop", ":reduction", + ":transpose", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", @@ -147,3 +148,25 @@ cc_library( "@llvm-project//mlir:MemRefDialect", ], ) + +cc_library( + name = "transpose", + srcs = ["transpose.cc"], + hdrs = ["transpose.h"], + deps = [ + ":fusion_emitter", + ":tiling_util", + "//tensorflow/compiler/xla:permutation_util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", + "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:target_util", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm-project//llvm:ir_headers", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index 34723c1ec64eef..cbbd60050eda7f 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" #include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -53,6 +54,9 @@ std::optional> GetFusionEmitter( case HloFusionAnalysis::EmitterFusionKind::kReduction: return std::make_unique( ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); + case HloFusionAnalysis::EmitterFusionKind::kTranspose: + return std::make_unique( + ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { bool is_single = IsSingleInstructionFusion(fusion_op); if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( diff --git a/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc b/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc new file mode 100644 index 00000000000000..c56323fc4aca78 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc @@ -0,0 +1,232 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" + +#include + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/permutation_util.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace gpu { +namespace { + +llvm::GlobalVariable* AllocateShared( + llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, + llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name) { + CHECK(!dimensions_major_to_minor.empty()); + llvm::Type* ty = element_type; + for (auto dim : llvm::reverse(dimensions_major_to_minor)) { + ty = llvm::ArrayType::get(ty, dim); + } + ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); + return llvm_ir::AllocateSharedMemoryTile( + builder->GetInsertBlock()->getModule(), ty, buffer_name); +} + +void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context) { + auto* module = builder->GetInsertBlock()->getModule(); + if (IsAMDGPU(module) && + ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( + 0, 6) == "gfx90a") { + builder->CreateFence( + llvm::AtomicOrdering::SequentiallyConsistent, + builder->getContext().getOrInsertSyncScopeID("workgroup")); + } +} + +void EmitSyncThreads(llvm::IRBuilder<>* builder, + IrEmitterContext& ir_emitter_context) { + MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); +} + +llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index, + absl::Span permutation) { + return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation), + Permute(index.dims(), permutation), + index.GetType()}; +} + +} // namespace + +Status TransposeFusion::EmitKernel(const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder, + int kernel_index) const { + const auto& tiling_scheme = *analysis_.GetTransposeTilingScheme(); + std::vector hlo_roots = + GetFusionRoots(fusion().fused_instructions_computation()); + FusedIrEmitter fused_emitter(elemental_emitter()); + for (auto [i, input] : llvm::enumerate(inputs)) { + HloInstruction* fused_operand = fusion().fused_parameter(i); + fused_emitter.BindGenerator( + *fused_operand, [input = input, builder, + fused_operand](const llvm_ir::IrArray::Index& index) { + return input.EmitReadArrayElement(index, builder, + fused_operand->name()); + }); + } + + absl::flat_hash_map tiles; + Vector3 permutation; + for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { + if (auto tr = FindAnyTiledTranspose(*root)) { + permutation = tr->permutation; + const HloInstruction& hero = FindNonTrivialHero(*root); + tiles[&hero] = AllocateShared( + builder, tiling_scheme, + llvm_ir::PrimitiveTypeToIrType( + hero.operand(0)->shape().element_type(), + ir_emitter_context().llvm_module()), + {tiling_scheme.GetBlockTileSizeFor(permutation[TilingScheme::DimX]), + tiling_scheme.GetBlockTileSizeFor(TilingScheme::DimX) + 1}, + absl::StrCat("tr_tile_", tile_idx)); + } + } + + TileElementGenerator tile_generator = + [&](const TilingThreadIdInfo& thread_id_info, + const 
llvm_ir::IrArray::Index& index, + std::array tile_dimensions) { + // Copy input parameter values to shared memory buffers: + // tile[thread_id_y, thread_id_x] = input[index] + // Note that tile_width and tile_height are flipped here because we + // are reading a transposed tile. + EmitTile( + builder, tiling_scheme, index, thread_id_info, tile_dimensions, + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + // Compute all extra output values before writing them. This + // avoids overwriting aliased input/output values before all reads + // occurred. + std::vector> + scheduled_writes; + + for (const auto& [output_idx, root] : + llvm::enumerate(hlo_roots)) { + if (FindAnyTiledTranspose(*root)) { + const HloInstruction& hero = FindNonTrivialHero(*root); + llvm_ir::ElementGenerator input_gen = + *fused_emitter.GetGenerator(*hero.operand(0)); + llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( + index, hero.operand(0)->shape(), builder, + tiling_scheme.GetDimsInElems()); + llvm::Value* value = *input_gen(untiled_index); + llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( + builder, tiles[&hero], {y_loc, x_loc}); + + builder->CreateStore(value, addr); + } else { + llvm_ir::IrArray::Index untiled_index = + GetUnnormalizedIndex(index, root->shape(), builder, + tiling_scheme.GetDimsInElems()); + llvm_ir::ElementGenerator output_gen = + *fused_emitter.GetGenerator(*root); + llvm::Value* output_value = *output_gen(untiled_index); + scheduled_writes.emplace_back(outputs[output_idx], + untiled_index, output_value); + } + } + + for (const auto& [output, idx, value] : scheduled_writes) { + output.EmitWriteArrayElement(idx, value, builder); + } + }); + + EmitSyncThreads(builder, ir_emitter_context()); + + llvm_ir::IrArray::Index output_tile_index = + PermuteIndex(index, permutation); + std::array transposed_tile_dimensions = { + tile_dimensions[1], tile_dimensions[0]}; + + EmitTile( + builder, tiling_scheme, output_tile_index, thread_id_info, + transposed_tile_dimensions, + /*emit_elem_function=*/ + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (const auto& [output_idx, root] : + llvm::enumerate(hlo_roots)) { + if (FindAnyTiledTranspose(*root)) { + const HloInstruction& hero = FindNonTrivialHero(*root); + + std::vector idx = {x_loc, y_loc}; + llvm::Value* gep = thread_id_info.GEPIntoSharedMemory( + builder, tiles[&hero], idx); + llvm::Type* type = + thread_id_info.GEPIntoSharedMemoryType(tiles[&hero], idx); + llvm::Value* loaded = + builder->CreateLoad(type, gep, "tiled_buffer"); + + FusedIrEmitter fused_emitter(elemental_emitter()); + fused_emitter.BindGenerator( + hero, [&](const llvm_ir::IrArray::Index& index) { + return loaded; + }); + for (int64_t i = 0; i < fusion() + .fused_instructions_computation() + ->num_parameters(); + ++i) { + llvm_ir::IrArray ir_array = inputs[i]; + HloInstruction* fused_operand = fusion().fused_parameter(i); + fused_emitter.BindGenerator( + *fused_operand, + [=](const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement( + index, builder, fused_operand->name()); + }); + } + + // Apply codegeneration for the code after the real hero. + TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, + fused_emitter.GetGenerator(*root)); + + // Both for emission and writing it should be + // index-as-transformed by the computation. 
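            // That is, the fused epilogue (everything between the transpose
            // hero and this root) is evaluated and written at output
            // coordinates, which are the tiled input coordinates permuted by
            // `permutation`; hence GetDimsInElems() is permuted below before
            // the index is unnormalized.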
+ llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( + index, root->shape(), builder, + Permute(tiling_scheme.GetDimsInElems(), permutation)); + TF_ASSIGN_OR_RETURN(llvm::Value * generated, + gen(untiled_index)); + outputs[output_idx].EmitWriteArrayElement(untiled_index, + generated, builder); + } + } + return OkStatus(); + }); + }; + + llvm::Type* index_type = + GetIndexTypeForKernel(fusion_op(), launch_dims.launch_bound(), builder); + return EmitTilingKernel(builder, tiling_scheme, index_type, tile_generator) + .status(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/transpose.h b/tensorflow/compiler/xla/service/gpu/fusions/transpose.h new file mode 100644 index 00000000000000..cfd1ff9eef96c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/transpose.h @@ -0,0 +1,81 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ + +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" + +namespace xla { +namespace gpu { + +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// with a shape that is a 0-2-1 transpose of the output tensor shape. The +// caller is responsible for making sure that it is safe to apply the shared +// memory transpose on the input parameters. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape +// of three components 0-1-2 in the order major to minor. The x- and y- +// dimensions of the tensors are tiled in square tiles with an edge length +// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads +// transposes one tile: each thread copies kTileSize/kNumRows elements from +// the input to a shared memory tile, then the otherwise "regular HLO kernel" +// reads from the shared memory instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more +// efficient to launch fewer blocks so each transposes many tiles. 
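//
// In CUDA-like pseudocode the emitted kernel is roughly (a sketch only; the
// real emitter also runs the fused prologue before the shared-memory store
// and the fused epilogue after the transposed load):
//
//   __shared__ float tile[kTileSize][kTileSize + 1];  // +1 avoids bank conflicts
//   for (int j = 0; j < kTileSize; j += kNumRows)
//     tile[threadIdx.y + j][threadIdx.x] = in[y0 + threadIdx.y + j][x0 + threadIdx.x];
//   __syncthreads();
//   for (int j = 0; j < kTileSize; j += kNumRows)
//     out[x0 + threadIdx.y + j][y0 + threadIdx.x] = tile[threadIdx.x][threadIdx.y + j];
//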
+class TransposeFusion : public KernelFusionEmitterBase { + public: + TransposeFusion(IrEmitterContext& ir_emitter_context, + ElementalIrEmitter& elemental_emitter, + mlir::lmhlo::FusionOp fusion_op, + const HloFusionInstruction& fusion, + HloFusionAnalysis& analysis) + : KernelFusionEmitterBase(ir_emitter_context, elemental_emitter, + fusion_op, fusion), + analysis_(analysis) {} + StatusOr launch_dimensions() const override { + return analysis_.GetLaunchDimensions(false); + } + + protected: + Status EmitKernel(const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder, + int kernel_index) const override; + + private: + HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 3e23e291c96330..e220091a188ddd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -69,7 +68,6 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project -#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" @@ -79,7 +77,6 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/transforms/gpu_passes.h" -#include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -93,7 +90,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fused_mha_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fusions/fusions.h" #include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" @@ -111,7 +107,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" -#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h" @@ -1946,32 +1941,6 @@ Status IrEmitterUnnested::EmitTritonFusion( #endif // GOOGLE_CUDA -Status IrEmitterUnnested::EmitUnnestedTranspose( - mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { - auto* tiling_scheme = fusion_analysis.GetTransposeTilingScheme(); - // Set flag to false as Transpose has it's own custom logic of choosing a - // block size. - TF_ASSIGN_OR_RETURN(auto launch_dimensions, - fusion_analysis.GetLaunchDimensions( - /*use_experimental_block_size=*/false)); - - TF_ASSIGN_OR_RETURN( - std::optional> opt_ir_arrays, - BuildKernelThunkForFusion(fusion, launch_dimensions)); - if (!opt_ir_arrays.has_value()) { - // The kernel was reused, no need to emit code. - return OkStatus(); - } - std::vector& ir_arrays = opt_ir_arrays.value(); - - TF_RETURN_IF_ERROR(EmitTransposeTile( - fusion, fusion_analysis.fused_computation(), - absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size()), - absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size()), - *tiling_scheme, launch_dimensions)); - return OkStatus(); -} - Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { auto fusion_op = mlir::cast(op); @@ -2047,14 +2016,13 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } - case HloFusionAnalysis::EmitterFusionKind::kTranspose: - return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: case HloFusionAnalysis::EmitterFusionKind::kReduction: + case HloFusionAnalysis::EmitterFusionKind::kTranspose: return FailedPrecondition( "Loop fusion should have been handled by GetFusionEmitter."); } @@ -3293,194 +3261,6 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } -llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, - llvm::Type* element_type, - llvm::Twine name) { - return b_.CreateAddrSpaceCast(input, - llvm::PointerType::get(element_type, - /*AddressSpace=*/0), - name); -} - -llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { - MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering::SequentiallyConsistent, - "workgroup"); - return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); -} - -static IrArray::Index PermuteIndex(const IrArray::Index& index, - absl::Span permutation) { - return IrArray::Index{Permute(index.multidim(), permutation), - Permute(index.dims(), permutation), index.GetType()}; -} - -Status IrEmitterUnnested::EmitTransposeTile( - mlir::lmhlo::FusionOp fusion, const HloComputation* fusion_hlo, - absl::Span operand_arrays, - absl::Span output_arrays, - const TilingScheme& tiling_scheme, - const LaunchDimensions& launch_dimensions) { - std::vector hlo_roots = - GetFusionRoots(const_cast(fusion_hlo)); - FusedIrEmitter 
fused_emitter(elemental_emitter_); - for (int i = 0; i < fusion_hlo->num_parameters(); i++) { - llvm_ir::IrArray ir_array = operand_arrays[i]; - HloInstruction* fused_operand = fusion_hlo->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, - [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement(index, &b_, - fused_operand->name()); - }); - } - - absl::flat_hash_map tiles; - Vector3 permutation; - for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { - if (auto tr = FindAnyTiledTranspose(*root)) { - permutation = tr->permutation; - const HloInstruction& hero = FindNonTrivialHero(*root); - tiles[&hero] = - AllocateShared(tiling_scheme, - llvm_ir::PrimitiveTypeToIrType( - hero.operand(0)->shape().element_type(), module_), - {tiling_scheme.GetBlockTileSizeFor(permutation[kDimX]), - tiling_scheme.GetBlockTileSizeFor(kDimX) + 1}, - absl::StrCat("tr_tile_", tile_idx)); - } - } - - TileElementGenerator tile_generator = [&](const TilingThreadIdInfo& - thread_id_info, - const IrArray::Index& index, - ValueVector2 tile_dimensions) { - // Copy input parameter values to shared memory buffers: - // tile[thread_id_y, thread_id_x] = input[index] - // Note that tile_width and tile_height are flipped here because we - // are reading a transposed tile. - EmitTile( - &b_, tiling_scheme, index, thread_id_info, tile_dimensions, - [&](const TilingThreadIdInfo& thread_id_info, - const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - // Compute all extra output values before writing them. This avoids - // overwriting aliased input/output values before all reads occurred. - std::vector> - scheduled_writes; - - for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*hero.operand(0)); - IrArray::Index untiled_index = - GetUnnormalizedIndex(index, hero.operand(0)->shape(), &b_, - tiling_scheme.GetDimsInElems()); - llvm::Value* value = *input_gen(untiled_index); - llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( - &b_, tiles[&hero], {y_loc, x_loc}); - - b_.CreateStore(value, addr); - } else { - IrArray::Index untiled_index = GetUnnormalizedIndex( - index, root->shape(), &b_, tiling_scheme.GetDimsInElems()); - llvm_ir::ElementGenerator output_gen = - *fused_emitter.GetGenerator(*root); - llvm::Value* output_value = *output_gen(untiled_index); - scheduled_writes.emplace_back(output_arrays[output_idx], - untiled_index, output_value); - } - } - - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, &b_); - } - }); - - EmitSyncThreads(); - - IrArray::Index output_tile_index = PermuteIndex(index, permutation); - ValueVector2 transposed_tile_dimensions = {tile_dimensions[1], - tile_dimensions[0]}; - - EmitTile( - &b_, tiling_scheme, output_tile_index, thread_id_info, - transposed_tile_dimensions, - /*emit_elem_function=*/ - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); - - std::vector idx = {x_loc, y_loc}; - llvm::Value* gep = - thread_id_info.GEPIntoSharedMemory(&b_, tiles[&hero], idx); - llvm::Type* type = - 
thread_id_info.GEPIntoSharedMemoryType(tiles[&hero], idx); - llvm::Value* loaded = b_.CreateLoad(type, gep, "tiled_buffer"); - - FusedIrEmitter fused_emitter(elemental_emitter_); - fused_emitter.BindGenerator( - hero, [&](const IrArray::Index& index) { return loaded; }); - for (int64_t i = 0; i < fusion_hlo->num_parameters(); ++i) { - llvm_ir::IrArray ir_array = operand_arrays[i]; - HloInstruction* fused_operand = - fusion_hlo->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, [this, ir_array, fused_operand]( - const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement( - index, &b_, fused_operand->name()); - }); - } - - // Apply codegeneration for the code after the real hero. - TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, - fused_emitter.GetGenerator(*root)); - - // Both for emission and writing it should be index-as-transformed - // by the computation. - IrArray::Index untiled_index = GetUnnormalizedIndex( - index, root->shape(), &b_, - Permute(tiling_scheme.GetDimsInElems(), permutation)); - TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index)); - output_arrays[output_idx].EmitWriteArrayElement(untiled_index, - generated, &b_); - } - } - return OkStatus(); - }); - }; - - llvm::Type* index_type = GetIndexTypeForKernel( - fusion.getOperation(), launch_dimensions.launch_bound(), &b_); - return EmitTilingKernel(&b_, tiling_scheme, index_type, tile_generator) - .status(); -} - -llvm::GlobalVariable* IrEmitterUnnested::AllocateShared( - const TilingScheme& tiling_scheme, llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* array_type = nullptr; - for (int i = dimensions_major_to_minor.size() - 1; i >= 0; i--) { - // Iterate in minor-to-major order. - int64_t dim = dimensions_major_to_minor[i]; - if (!array_type) { - array_type = llvm::ArrayType::get(element_type, dim); - } else { - array_type = llvm::ArrayType::get(array_type, dim); - } - } - array_type = llvm::ArrayType::get(array_type, - tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(), - array_type, buffer_name); -} - // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 905ed163d64ead..39f5b34e23238f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -324,33 +324,6 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenInfo& reduction_info, const ExtraOutputGensMap& extra_output_gens); - // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose - // algorithm to improve the memory access patterns for the input parameters - // with a shape that is a 0-2-1 transpose of the output tensor shape. The - // caller is responsible for making sure that it is safe to apply the shared - // memory transpose on the input parameters. - // - // - // For the purpose of tiling, the output tensors have a logical shape of three - // components 0-2-1 while the relevant input parameters have a logical shape - // of three components 0-1-2 in the order major to minor. The x- and y- - // dimensions of the tensors are tiled in square tiles with an edge length - // `kTileSize`. 
Each thread block of `kTileSize` x `kNumRows` threads - // transposes one tile: each thread copies kTileSize/kNumRows elements from - // the input to a shared memory tile, then the otherwise "regular HLO kernel" - // reads from the shared memory instead of the original input. - // - // This is similar to the following CUDA algorithm in TensorFlow: - // https://goo.gl/MStRV6. - // - // `kTileSize` should usually be same as warp size. We currently choose 32 for - // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. - // - // TODO(b/33320379): Here each block transposes 1 tile. It may be more - // efficient to launch fewer blocks so each transposes many tiles. - Status EmitUnnestedTranspose(mlir::lmhlo::FusionOp fusion, - HloFusionAnalysis& fusion_analysis); - // Generates code for input-fusible slices. // // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes @@ -413,17 +386,6 @@ class IrEmitterUnnested : public IrEmitter { const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis); - // Allocates a shared tile of given dimensions, applying scaling specified in - // tilng_scheme as a major-most dimension to avoid collisions. - llvm::GlobalVariable* AllocateShared( - const TilingScheme& tiling_scheme, llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name = ""); - - // Removes some unneeded defining operations from the calculation of `value`, - // before passing it to a KernelThunk. - static StatusOr RemoveTransformingOperations(mlir::Value value); - // Builds a thunk that calls a new or reused kernel for a fusion operation. // // The caller must specify the same launch dimensions for fusions which have @@ -492,9 +454,6 @@ class IrEmitterUnnested : public IrEmitter { StatusOr> BuildConditionalThunk( const HloInstruction* conditional); - // Emit __syncthreads(), synchronization barrier for all threads in a block. - llvm::CallInst* EmitSyncThreads(); - StatusOr GetOrCreateSubComputationFromRegion( mlir::Region* region, bool is_fusion); @@ -513,11 +472,6 @@ class IrEmitterUnnested : public IrEmitter { scratch_nested_computations_; // End optional members for XLA HLO -> LMHLO. - // __shared__ memory uses a different address space, so we cast it to - // global address space before writing or reading. - llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type, - llvm::Twine name = ""); - // Returns the ShapedSlices for the given operands. StatusOr> GetShapedSlices( mlir::Operation::operand_range operands); From 0e2accc70cc22570166181d8599885e4620aef08 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Tue, 25 Jul 2023 03:37:23 -0700 Subject: [PATCH 101/410] Ensure that TfLiteRegistration objects are properly initialized. 
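The distinction matters because TfLiteRegistration is a plain C struct: a default-initialized local leaves its function pointers and integer fields indeterminate, while `{}` value-initializes every member to zero. A minimal sketch of the difference, using a hypothetical stand-in struct rather than the real TfLiteRegistration definition:

```c++
#include <cstdint>

// Stand-in for a C-style registration struct; the real TfLiteRegistration has
// many more fields, but the initialization rules are identical.
struct FakeRegistration {
  void (*init)(void*);   // function pointer
  int32_t builtin_code;  // plain integer
};

void Example() {
  FakeRegistration a;    // default-initialized: members hold indeterminate
                         // values, so reading them before assignment is UB.
  FakeRegistration b{};  // value-initialized: members are zeroed, which is what
                         // the updated tests rely on before they set individual
                         // fields such as builtin_code.
  (void)a;
  (void)b;
}
```

The changes below switch the test fixtures and the documentation example to the second form.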
PiperOrigin-RevId: 550834660 --- tensorflow/lite/arena_planner_test.cc | 2 +- .../lite/core/tools/verifier_internal_test.cc | 2 +- tensorflow/lite/core/tools/verifier_test.cc | 14 +++++++------- tensorflow/lite/delegates/utils_test.cc | 2 +- .../lite/g3doc/models/convert/operation_fusion.md | 5 +++-- tensorflow/lite/simple_planner_test.cc | 2 +- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tensorflow/lite/arena_planner_test.cc b/tensorflow/lite/arena_planner_test.cc index ded2345c93e8c7..2a434d734f0ec9 100644 --- a/tensorflow/lite/arena_planner_test.cc +++ b/tensorflow/lite/arena_planner_test.cc @@ -59,7 +59,7 @@ class TestOp { : inputs_(inputs), outputs_(outputs), temporaries_(temporaries), - registration_(TfLiteRegistration()) { + registration_{} { registration_.builtin_code = builtin_code; registration_.inplace_operator = inplace_operator; } diff --git a/tensorflow/lite/core/tools/verifier_internal_test.cc b/tensorflow/lite/core/tools/verifier_internal_test.cc index fa063a11f17231..50b52def2f6fad 100644 --- a/tensorflow/lite/core/tools/verifier_internal_test.cc +++ b/tensorflow/lite/core/tools/verifier_internal_test.cc @@ -128,7 +128,7 @@ class TfLiteFlatbufferModelBuilder { flatbuffers::FlatBufferBuilder builder_; MutableOpResolver resolver_; - TfLiteRegistration fake_op_; + TfLiteRegistration fake_op_{}; std::vector> operators_; std::vector> operator_codes_; std::vector> tensors_; diff --git a/tensorflow/lite/core/tools/verifier_test.cc b/tensorflow/lite/core/tools/verifier_test.cc index b0681bf295125a..3d7378433e86d6 100644 --- a/tensorflow/lite/core/tools/verifier_test.cc +++ b/tensorflow/lite/core/tools/verifier_test.cc @@ -155,7 +155,7 @@ class TfLiteFlatbufferModelBuilder { flatbuffers::FlatBufferBuilder builder_; MutableOpResolver resolver_; - TfLiteRegistration fake_op_; + TfLiteRegistration fake_op_{}; MockErrorReporter mock_reporter_; std::vector> operators_; std::vector> operator_codes_; @@ -628,7 +628,7 @@ TEST(VerifyModel, SimpleValidSparseTensor) { ::tflite::FinishModelBuffer(builder, model_); MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_TRUE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); @@ -653,7 +653,7 @@ TEST(VerifyModel, InvalidSparseTensorMissingBlockMap) { ::tflite::FinishModelBuffer(builder, model_); MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_FALSE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); @@ -681,7 +681,7 @@ TEST(VerifyModel, InvalidSparseTensorIndexOutOfBound) { ::tflite::FinishModelBuffer(builder, model_); MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_FALSE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); @@ -708,7 +708,7 @@ TEST(VerifyModel, InvalidSparseTensorInvalidBuffer) { ::tflite::FinishModelBuffer(builder, model_); MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_FALSE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); @@ -737,7 +737,7 @@ TEST(VerifyModel, InvalidSparseTensorInvalidTraversalOrder) { ::tflite::FinishModelBuffer(builder, model_); 
MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_FALSE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); @@ -778,7 +778,7 @@ TEST(VerifyModel, ValidSparseTensorBCSC) { ::tflite::FinishModelBuffer(builder, model_); MockErrorReporter mock_reporter; MutableOpResolver resolver; - TfLiteRegistration fake_op; + TfLiteRegistration fake_op{}; resolver.AddCustom("FakeOp", &fake_op); ASSERT_TRUE( Verify(builder.GetBufferPointer(), builder.GetSize(), &mock_reporter)); diff --git a/tensorflow/lite/delegates/utils_test.cc b/tensorflow/lite/delegates/utils_test.cc index afa069607fcd39..e13e285fa31c57 100644 --- a/tensorflow/lite/delegates/utils_test.cc +++ b/tensorflow/lite/delegates/utils_test.cc @@ -205,7 +205,7 @@ class MockTfLiteContext : public TfLiteContext { // For simplicity, the mocked graph has only type of node and one // registration. TfLiteNode node_; - TfLiteRegistration registration_; + TfLiteRegistration registration_{}; // The TfLiteDelegateParams object that's manually populated inside the mocked // TfLiteContext::PreviewDelegatePartitioning. diff --git a/tensorflow/lite/g3doc/models/convert/operation_fusion.md b/tensorflow/lite/g3doc/models/convert/operation_fusion.md index 747cfc721be141..16450f5b40e65b 100644 --- a/tensorflow/lite/g3doc/models/convert/operation_fusion.md +++ b/tensorflow/lite/g3doc/models/convert/operation_fusion.md @@ -127,8 +127,9 @@ Note that, the name to register the op with should be similar to the name specified in the `name` attribute in the implements signature. An example for the op in the example is + ```c++ - TfLiteRegistration reg; + TfLiteRegistration reg = {}; // This name must match the name specified in the implements signature. static constexpr char kOpName[] = "my_custom_fused_op"; reg.custom_name = kOpName; @@ -137,7 +138,7 @@ An example for the op in the example is return kTfLiteOk; }; reg.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus { - // Add your coder. + // Add your code. return kTfLiteOk; }; reg.builtin_code = kTfLiteCustom; diff --git a/tensorflow/lite/simple_planner_test.cc b/tensorflow/lite/simple_planner_test.cc index 0b49600f569d39..93fb693519f5c5 100644 --- a/tensorflow/lite/simple_planner_test.cc +++ b/tensorflow/lite/simple_planner_test.cc @@ -46,7 +46,7 @@ class TestOp { std::vector inputs_; std::vector outputs_; std::vector temporaries_; - TfLiteRegistration registration_; + TfLiteRegistration registration_{}; }; // A test graph where inputs are processed by the given nodes to produce From d928bbba955b15eed30cf218c5d7e281e74affde Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Thu, 13 Jul 2023 17:32:20 +0100 Subject: [PATCH 102/410] Fix ambiguity in use of overloaded functions in XLA Cast ambiguous parameter so that it is not ambiguous and so gcc will compile it. 
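The underlying issue is general: a bare braced-init-list can match more than one overload or constructor equally well, and gcc then rejects the call as ambiguous; spelling out the parameter type binds the list to a single overload. A minimal, self-contained sketch with made-up types (none of these are the real XLA classes):

```c++
#include <cstdint>
#include <initializer_list>

// Two unrelated parameter types that can both be constructed from {2, 1}.
struct FakeSpan { FakeSpan(std::initializer_list<int64_t>) {} };
struct FakeDims { FakeDims(std::initializer_list<int64_t>) {} };

// A constructor set resembling the ambiguous call site.
struct FakeTileAssignment {
  explicit FakeTileAssignment(FakeSpan) {}
  explicit FakeTileAssignment(FakeDims) {}
};

int main() {
  // FakeTileAssignment t({2, 1});       // error: both constructors match the list.
  FakeTileAssignment t(FakeSpan{2, 1});  // OK: naming the type first selects one
                                         // overload, which is what the explicit
                                         // absl::Span construction below does.
  (void)t;
  return 0;
}
```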
--- .../compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc index a98a0271a03a41..f01435d597e32d 100644 --- a/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc +++ b/tensorflow/compiler/xla/python/pjrt_ifrt/xla_sharding_serdes_test.cc @@ -34,7 +34,8 @@ class XlaShardingSerDesTest : public test_util::ShardingTest {}; TEST_P(XlaShardingSerDesTest, HloShardingRoundTrip) { auto device_list = GetDevices({0, 1}); - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto xla_hlo_sharding = xla::HloSharding::Tile( + xla::TileAssignment((absl::Span){2, 1})); auto sharding = HloSharding::Create(device_list, /*xla_hlo_sharding=*/xla_hlo_sharding); From 46d60005ad28dfa5a6514a202b27df8d84c95345 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Tue, 25 Jul 2023 05:05:25 -0700 Subject: [PATCH 103/410] Extract input slices fusion from ir_emitter_unnested. PiperOrigin-RevId: 550850343 --- .../compiler/xla/service/gpu/fusions/BUILD | 18 ++ .../xla/service/gpu/fusions/fusions.cc | 4 + .../xla/service/gpu/fusions/input_slices.cc | 184 ++++++++++++++ .../xla/service/gpu/fusions/input_slices.h | 59 +++++ .../xla/service/gpu/ir_emitter_unnested.cc | 225 +----------------- .../xla/service/gpu/ir_emitter_unnested.h | 29 --- 6 files changed, 266 insertions(+), 253 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc create mode 100644 tensorflow/compiler/xla/service/gpu/fusions/input_slices.h diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 5bd8adde24b674..22aeb4d6b48827 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -60,6 +60,7 @@ cc_library( ":copy", ":fusion_emitter", ":in_place_dynamic_update_slice", + ":input_slices", ":loop", ":reduction", ":transpose", @@ -170,3 +171,20 @@ cc_library( "@llvm-project//llvm:ir_headers", ], ) + +cc_library( + name = "input_slices", + srcs = ["input_slices.cc"], + hdrs = ["input_slices.h"], + deps = [ + ":fusion_emitter", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm-project//llvm:ir_headers", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index cbbd60050eda7f..602caa5ddc6ba5 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/fusions/copy.h" #include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/input_slices.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" #include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" #include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" @@ -51,6 +52,9 @@ std::optional> GetFusionEmitter( ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion) { switch (analysis.GetEmitterFusionKind()) { + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: + return std::make_unique( + ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kReduction: return std::make_unique( ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); diff --git a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc new file mode 100644 index 00000000000000..2e3edc3a0bbbb0 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc @@ -0,0 +1,184 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/gpu/fusions/input_slices.h" + +#include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace gpu { +namespace { + +// Emits code for slices based on the below structure. An if statement with +// a guarding condition is generated for each ROOT slice. 
+// +// Pseudo code: +// +// Compute values of slice input operands +// +// Compute guarding_cond0 +// if (guarding_cond0) { +// Write to output of slice0 +// } +// +// Compute guarding_cond1 +// if (guarding_cond1) { +// Write to output of slice1 +// } +// +Status EmitElementForInputFusibleSlices( + ElementalIrEmitter& elemental_emitter, + const HloComputation* fused_computation, + const std::vector& inputs, + const std::vector& outputs, + const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) { + VLOG(10) << "Emitting slice input fusion for " + << fused_computation->ToString(); + + HloInstruction* slice_or_tuple = fused_computation->root_instruction(); + auto slice_instructions = [&]() -> absl::Span { + if (slice_or_tuple->opcode() == HloOpcode::kSlice) { + return absl::Span(&slice_or_tuple, 1); + } + CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); + return slice_or_tuple->operands(); + }(); + + // Emit input operand values of slices. + std::vector input_ir_values; + FusedIrEmitter fused_emitter(elemental_emitter); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + fused_emitter.BindGenerator( + *fused_computation->parameter_instruction(i), + [&inputs, i, builder](llvm_ir::IrArray::Index index) { + return inputs[i].EmitReadArrayElement(index, builder); + }); + } + for (const HloInstruction* slice : slice_instructions) { + auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0)); + input_ir_values.push_back(input_generator(index).value()); + } + + // Emit for slice_instructions. + KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); + for (int64_t i = 0; i < slice_instructions.size(); ++i) { + HloInstruction* slice = slice_instructions[i]; + + // guarding_cond := index >= start && index < limit, for each dim. + std::vector index_within_ranges; + for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { + CHECK_EQ(slice->slice_strides(dim), 1); + auto larger_or_equal_than_start = builder->CreateICmpSGE( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + llvm::Value* smaller_than_limit = builder->CreateICmpSLT( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_limits(dim))); + llvm::Value* within_range = + builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit); + index_within_ranges.push_back(within_range); + } + llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges); + + auto emit_slice_elem_func = [&] { + const std::vector& src_multidim = index.multidim(); + std::vector dst_multidim(src_multidim.size()); + for (size_t dim = 0; dim < src_multidim.size(); ++dim) { + dst_multidim[dim] = builder->CreateSub( + src_multidim[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + } + llvm_ir::IrArray src_ir_array = outputs[i]; + llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(), + index.GetType()); + src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], + builder); + }; + + ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func); + } + return OkStatus(); +} + +// Gets the input shape of the ROOT slices, which will be used as the kernel +// launch dims. The slice input fusion requires the input shapes of the ROOT +// slices to be the same although the (slice) output shapes can be different. +// +// Returns the input shape of the ROOT slices if all the input shapes of ROOT +// slices are the same and the slices are non-strided. Otherwise, returns +// FailedPrecondition. 
+StatusOr GetConsistentInputShapeForRootSlices( + const HloComputation* fused_computation) { + const HloInstruction& root = *fused_computation->root_instruction(); + if (root.opcode() == HloOpcode::kSlice) { + return root.operands()[0]->shape(); + } + + CHECK_EQ(root.opcode(), HloOpcode::kTuple); + const Shape& first_slice_operand_shape = + root.operands()[0]->operands()[0]->shape(); + for (size_t i = 1; i < root.operands().size(); ++i) { + const HloInstruction* slice = root.operands()[i]; + const Shape& operand_shape = slice->operands()[0]->shape(); + if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, + operand_shape)) { + return FailedPrecondition( + "Fused slices do not have the same input shape, fused computation = " + "%s.", + root.parent()->name()); + } + } + + return first_slice_operand_shape; +} + +} // namespace + +StatusOr InputSlicesFusion::launch_dimensions() const { + bool use_experimental_block_size = + ir_emitter_context() + .debug_options() + .xla_gpu_enable_experimental_block_size(); + return analysis_.GetLaunchDimensions(use_experimental_block_size); +} + +Status InputSlicesFusion::EmitKernel(const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder, + int kernel_index) const { + TF_ASSIGN_OR_RETURN(Shape element_shape, + GetConsistentInputShapeForRootSlices( + fusion().fused_instructions_computation())); + return ParallelLoopEmitter( + [&](const llvm_ir::IrArray::Index index) -> Status { + return EmitElementForInputFusibleSlices( + elemental_emitter(), + fusion().fused_instructions_computation(), inputs, outputs, + index, builder); + }, + element_shape, launch_dims, builder) + .EmitLoop(llvm_ir::IrName(GetIrNameFromLoc(fusion_op().getLoc())), + GetIndexTypeForKernel(fusion_op(), launch_dims.launch_bound(), + builder)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h new file mode 100644 index 00000000000000..88f6ec90bf0b34 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h @@ -0,0 +1,59 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ + +#include + +#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" + +namespace xla { +namespace gpu { + +// Generates code for input-fusible slices. +// +// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes +// of all ROOT slices need to be the same while their output shapes can be +// different. On the other hand, the input ranges of slices can be +// overlapping. Further generalization/specialization when the needs are seen +// in the future. 
+class InputSlicesFusion : public KernelFusionEmitterBase { + public: + InputSlicesFusion(IrEmitterContext& ir_emitter_context, + ElementalIrEmitter& elemental_emitter, + mlir::lmhlo::FusionOp fusion_op, + const HloFusionInstruction& fusion, + HloFusionAnalysis& analysis) + : KernelFusionEmitterBase(ir_emitter_context, elemental_emitter, + fusion_op, fusion), + analysis_(analysis) {} + StatusOr launch_dimensions() const override; + + protected: + Status EmitKernel(const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder, + int kernel_index) const override; + + private: + HloFusionAnalysis& analysis_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index e220091a188ddd..75cdc5978a7e4e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,7 +32,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -157,7 +156,6 @@ namespace gpu { namespace { -using absl::InlinedVector; using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; @@ -240,61 +238,6 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, } } -bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { - int instruction_count = 0; - for (mlir::Operation& instr : fusion.getRegion().front()) { - if (mlir::isa( - &instr)) { - continue; - } - instruction_count++; - } - return instruction_count == 1; -} - -// Gets the input shape of the ROOT slices, which will be used as the kernel -// launch dims. The slice input fusion requires the input shapes of the ROOT -// slices to be the same although the (slice) output shapes can be different. -// -// Returns the input shape of the ROOT slices if all the input shapes of ROOT -// slices are the same and the slices are non-strided. Otherwise, returns -// FailedPrecondition. -StatusOr GetConsistentInputShapeForRootSlices( - const HloComputation* fused_computation) { - const HloInstruction& root = *fused_computation->root_instruction(); - if (root.opcode() == HloOpcode::kSlice) { - return root.operands()[0]->shape(); - } - - CHECK_EQ(root.opcode(), HloOpcode::kTuple); - const Shape& first_slice_operand_shape = - root.operands()[0]->operands()[0]->shape(); - for (size_t i = 1; i < root.operands().size(); ++i) { - const HloInstruction* slice = root.operands()[i]; - const Shape& operand_shape = slice->operands()[0]->shape(); - if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, - operand_shape)) { - return FailedPrecondition( - "Fused slices do not have the same input shape, fused computation = " - "%s.", - root.parent()->name()); - } - } - - return first_slice_operand_shape; -} - -// For a row reduction, returns the number of rows we can process in parallel -// per warp. 
-int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; -} - StatusOr AsCudnnfMHAKind( mlir::lmhlo_gpu::FusedMhaDagSignature signature) { switch (signature) { @@ -2016,10 +1959,9 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: case HloFusionAnalysis::EmitterFusionKind::kLoop: case HloFusionAnalysis::EmitterFusionKind::kReduction: case HloFusionAnalysis::EmitterFusionKind::kTranspose: @@ -2028,44 +1970,6 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { } } -Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const Shape& reduction_operand_shape, - const ReductionOutputMap& result_ir_arrays, const IrArray::Index& index, - const ReductionCodegenInfo& reduction_info, - const ExtraOutputGensMap& extra_output_gens) { - if (extra_output_gens.empty()) { - return OkStatus(); - } - - // Compute all extra output values before writing them. This avoids - // overwriting aliased input/output buffers before all reads occurred. - std::vector> - extra_output_ir_values; - extra_output_ir_values.reserve(extra_output_gens.size()); - - auto get_index = [&](const HloInstruction* instr) { - const Shape& s = instr->shape(); - return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) - ? index - : index.SourceIndexOfBitcast(reduction_operand_shape, s, &b_); - }; - - for (const auto& [instr, generator] : extra_output_gens) { - TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, - generator(get_index(instr))); - extra_output_ir_values.emplace_back(instr, extra_output_ir_value); - } - - for (const auto& [instr, generator] : extra_output_ir_values) { - absl::Span result_ir = result_ir_arrays.at(instr); - CHECK_EQ(result_ir.size(), 1); - result_ir[0].EmitWriteArrayElement( - get_index(instr), generator, &b_, /*use_linear_index=*/ - reduction_info.GetNumPartialResults() == 1); - } - return OkStatus(); -} - Status IrEmitterUnnested::AssertNonDeterminismIsOkay( const std::string& op_name) { if (ir_emitter_context_->debug_options().xla_gpu_deterministic_ops()) { @@ -3261,133 +3165,6 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } -// Emits code for slices based on the below structure. An if statement with -// a guarding condition is generated for each ROOT slice. 
-// -// Pseudo code: -// -// Compute values of slice input operands -// -// Compute guarding_cond0 -// if (guarding_cond0) { -// Write to output of slice0 -// } -// -// Compute guarding_cond1 -// if (guarding_cond1) { -// Write to output of slice1 -// } -// -Status IrEmitterUnnested::EmitElementForInputFusibleSlices( - const HloComputation* fused_computation, - absl::Span ir_arrays, - const llvm_ir::IrArray::Index& index) { - VLOG(10) << "Emitting slice input fusion for " - << fused_computation->ToString(); - - HloInstruction* slice_or_tuple = fused_computation->root_instruction(); - auto slice_instructions = [&]() -> absl::Span { - if (slice_or_tuple->opcode() == HloOpcode::kSlice) { - return absl::Span(&slice_or_tuple, 1); - } - CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); - return slice_or_tuple->operands(); - }(); - - // Emit input operand values of slices. - std::vector input_ir_values; - FusedIrEmitter fused_emitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { - return ir_arrays[i].EmitReadArrayElement(index, &b_); - }); - } - for (const HloInstruction* slice : slice_instructions) { - auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0)); - input_ir_values.push_back(input_generator(index).value()); - } - - // Emit for slice_instructions. - KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - for (int64_t i = 0; i < slice_instructions.size(); ++i) { - HloInstruction* slice = slice_instructions[i]; - - // guarding_cond := index >= start && index < limit, for each dim. - std::vector index_within_ranges; - for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { - CHECK_EQ(slice->slice_strides(dim), 1); - auto larger_or_equal_than_start = b_.CreateICmpSGE( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - llvm::Value* smaller_than_limit = b_.CreateICmpSLT( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_limits(dim))); - llvm::Value* within_range = - b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit); - index_within_ranges.push_back(within_range); - } - llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges); - - auto emit_slice_elem_func = [&] { - const std::vector& src_multidim = index.multidim(); - std::vector dst_multidim(src_multidim.size()); - for (size_t dim = 0; dim < src_multidim.size(); ++dim) { - dst_multidim[dim] = - Sub(src_multidim[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - } - llvm_ir::IrArray src_ir_array = - ir_arrays[fused_computation->num_parameters() + i]; - IrArray::Index slice_dst_index(dst_multidim, slice->shape(), - index.GetType()); - src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], - &b_); - }; - - ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func); - } - return OkStatus(); -} - -Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( - mlir::Operation* op, HloFusionAnalysis& fusion_analysis) { - auto fusion = mlir::cast(op); - - TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, - GetOrCreateSubComputationFromRegion(&fusion.getRegion(), - /*is_fusion=*/true)); - - bool use_experimental_block_size = - ir_emitter_context_->debug_options() - .xla_gpu_enable_experimental_block_size(); - TF_ASSIGN_OR_RETURN( - LaunchDimensions launch_dimensions, - 
fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); - - TF_ASSIGN_OR_RETURN( - std::optional> opt_ir_arrays, - BuildKernelThunkForFusion(fusion, launch_dimensions)); - if (!opt_ir_arrays.has_value()) { - // The kernel was reused, no need to emit code. - return OkStatus(); - } - std::vector& ir_arrays = opt_ir_arrays.value(); - - TF_ASSIGN_OR_RETURN(Shape element_shape, - GetConsistentInputShapeForRootSlices(fused_computation)); - return ParallelLoopEmitter( - [&](const llvm_ir::IrArray::Index index) -> Status { - return EmitElementForInputFusibleSlices(fused_computation, - ir_arrays, index); - }, - element_shape, launch_dimensions, &b_) - .EmitLoop( - IrName(GetIrNameFromLoc(fusion.getLoc())), - GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_)); -} - Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 39f5b34e23238f..5409e2e74292e1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -317,28 +317,6 @@ class IrEmitterUnnested : public IrEmitter { absl::Span arguments, const LaunchDimensions& launch_dimensions); - // Helper for writing extra outputs from inside a reduce kernel. - Status EmitExtraOutputsForReduce(const Shape& reduction_operand_shape, - const ReductionOutputMap& result_ir_arrays, - const llvm_ir::IrArray::Index& index, - const ReductionCodegenInfo& reduction_info, - const ExtraOutputGensMap& extra_output_gens); - - // Generates code for input-fusible slices. - // - // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes - // of all ROOT slices need to be the same while their output shapes can be - // different. On the other hand, the input ranges of slices can be - // overlapping. Further generalization/specialization when the needs are seen - // in the future. - Status EmitInputFusibleNonStridedSlices(mlir::Operation* op, - HloFusionAnalysis& fusion_analysis); - - Status EmitElementForInputFusibleSlices( - const HloComputation* fused_computation, - absl::Span ir_arrays, - const llvm_ir::IrArray::Index& index); - // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. Scatter indices are taken from `scatter_indices_gen`, updates // from `updates_gen`. The output buffer is expected to have the operand @@ -375,13 +353,6 @@ class IrEmitterUnnested : public IrEmitter { Status EmitScatter(const ScatterDescriptor& desc, const LaunchDimensions& launch_dimensions); - Status EmitTransposeTile(mlir::lmhlo::FusionOp fusion, - const HloComputation* fusion_hlo, - absl::Span operand_arrays, - absl::Span output_arrays, - const TilingScheme& tiling_scheme, - const LaunchDimensions& launch_dimensions); - Status EmitScatter(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis); From 21cbb712ca04eb8a0c753d16a8cfb6b2b0c6218b Mon Sep 17 00:00:00 2001 From: Ian Hua Date: Tue, 25 Jul 2023 06:14:37 -0700 Subject: [PATCH 104/410] Improve DPB model coverage. 
PiperOrigin-RevId: 550864778 --- .../delegate_performance/android/models/BUILD | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models/BUILD index bb25d68bd4ce6a..8f944837c3c15c 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models/BUILD @@ -38,7 +38,7 @@ ACCURACY_MODELS = [ ), ] + accuracy_benchmark_extra_models() -LATENCY_MODELS = [ +BASIC_LATENCY_MODELS = [ ( "mobilenet_v1_1.0_224.tflite", "@tflite_mobilenet_float//:mobilenet_v1_1.0_224.tflite", @@ -47,7 +47,9 @@ LATENCY_MODELS = [ "mobilenet_v1_1.0_224_quant.tflite", "@tflite_mobilenet_quant//:mobilenet_v1_1.0_224_quant.tflite", ), -] + latency_benchmark_extra_models() +] + +LATENCY_MODELS = BASIC_LATENCY_MODELS + latency_benchmark_extra_models() COPY_CMD = """ srcs=($(SRCS)) @@ -72,6 +74,15 @@ genrule( cmd = COPY_CMD, ) +filegroup( + name = "latency_models_test_only", + testonly = True, + srcs = [ + "assets/latency/mobilenet_v1_1.0_224.tflite", + "assets/latency/mobilenet_v1_1.0_224_quant.tflite", + ], +) + # Latency criteria for latency benchmarking. filegroup( name = "latency_criteria_files", From f93b09e0f05272f0e2e99f30d912410815d07564 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 25 Jul 2023 06:25:02 -0700 Subject: [PATCH 105/410] [XLA] Simplify fusion_config object creation. There is no need to copy HloModuleConfig twice if we only want to modify one field. PiperOrigin-RevId: 550867272 --- .../xla/service/instruction_fusion.cc | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 731815aaf21bfe..2d1414baf5b8d2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -515,13 +515,7 @@ StatusOr InstructionFusion::Run( const absl::flat_hash_set& execution_threads) { bool changed = false; int64_t fuse_count = 0; - std::vector>* fusion_config = nullptr; - HloModuleConfig module_config; - if (config_collection_mode_ != FusionConfigCollection::kOff) { - module_config = module->config(); - fusion_config = module_config.mutable_fusion_config(); - fusion_config->clear(); - } + std::vector> fusion_config; bool dump_fusion = module->config().debug_options().xla_dump_fusion_visualization(); @@ -695,23 +689,22 @@ StatusOr InstructionFusion::Run( const std::vector* comp_fusion_config = fusion_queue->FusionConfiguration(); if (comp_fusion_config && !comp_fusion_config->empty()) { - fusion_config->push_back(*comp_fusion_config); + fusion_config.push_back(*comp_fusion_config); } } } if (config_collection_mode_ != FusionConfigCollection::kOff) { - int64_t fused_count = 0; - for (auto& config_per_computation : *fusion_config) { - for (auto edge : config_per_computation) { - if (edge) { - ++fused_count; - } + if (VLOG_IS_ON(1)) { + int64_t fused_count = 0; + for (auto& config_per_computation : fusion_config) { + fused_count += std::count(config_per_computation.begin(), + config_per_computation.end(), true); } + VLOG(1) << "There are " << fused_count << " fused bits that cause " + << fuse_count << " fusion actions."; } - VLOG(1) << "There are " << fused_count << " fused bits that cause " - << 
fuse_count << " fusion actions."; - module->set_config(module_config); + *module->config().mutable_fusion_config() = std::move(fusion_config); } VLOG(1) << "Fusion count: " << fuse_count; From 111929c093fd8d1d7861b9b2735f2fa1bb61a59e Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 25 Jul 2023 07:44:31 -0700 Subject: [PATCH 106/410] [XLA] Refactor code that dumps fusion states in helper functions (NFC). Purely cosmetic to reduce size of InstructionFusion::Run. PiperOrigin-RevId: 550883573 --- .../xla/service/instruction_fusion.cc | 84 +++++++++++-------- .../compiler/xla/service/instruction_fusion.h | 14 ++++ 2 files changed, 64 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 2d1414baf5b8d2..b9e70760ee80fa 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -584,14 +584,8 @@ StatusOr InstructionFusion::Run( should_fuse = ShouldFuse(instruction, i); if (should_fuse && consume_fuel()) { if (dump_fusion) { - RegisterFusionState( - *computation, - absl::StrCat("About to fuse |", operand->name(), "| into |", - instruction->name(), - "| inside InstructionFusion with may_duplicate=", - may_duplicate_), - /*consumer=*/*instruction, - /*producer=*/operand); + DumpPreFusionState(computation, /*consumer=*/instruction, + /*producer=*/operand); } fusion_queue->PreFusion(operand, instruction); @@ -611,14 +605,8 @@ StatusOr InstructionFusion::Run( if (can_fuse_mof) { if (consume_fuel()) { if (dump_fusion) { - RegisterFusionState( - *computation, - absl::StrCat( - "About to MOF-fuse |", operand->name(), "| into |", - instruction->name(), - "| inside InstructionFusion with may_duplicate=", - may_duplicate_), - /*consumer=*/*instruction, /*producer=*/operand); + DumpPreFusionState(computation, /*consumer=*/instruction, + /*producer=*/operand, /*is_mof=*/true); } fusion_queue->PreFusion(operand, instruction); @@ -636,17 +624,9 @@ StatusOr InstructionFusion::Run( << instruction->ToShortString() << "| as " << should_fuse.Explain(); - // Readability optimizations: lack of fusion for tuple accesses - // generates a lot of noise. - if (operand->opcode() != HloOpcode::kGetTupleElement && - instruction->opcode() != HloOpcode::kGetTupleElement) { - RegisterFusionState(*computation, - absl::StrCat("Not fusing |", operand->name(), - "| into |", instruction->name(), - "| as ", should_fuse.Explain()), - /*consumer=*/*instruction, - /*producer=*/operand); - } + DumpNotFusingState(computation, + /*consumer=*/instruction, + /*producer=*/operand, /*decision=*/should_fuse); } fusion_queue->NotFusingInstruction(operand, instruction); @@ -669,13 +649,7 @@ StatusOr InstructionFusion::Run( } if (dump_fusion) { - RegisterFusionState( - *computation, - absl::StrCat("Fused |", producer_name, "| into |", - fusion_instruction->name(), - "| inside InstructionFusion with may_duplicate=", - may_duplicate_), - *fusion_instruction); + DumpStateAfterFusion(computation, fusion_instruction, producer_name); } if (fusion_instruction != instruction) { @@ -1080,4 +1054,46 @@ bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer, return ReusedOperandsOf(consumer).contains(operand); } +void InstructionFusion::DumpPreFusionState(HloComputation* computation, + HloInstruction* consumer, + HloInstruction* producer, + bool is_mof) { + RegisterFusionState( + *computation, + absl::StrCat( + "About to ", is_mof ? 
"MOF-fuse" : "fuse", " |", producer->name(), + "| into |", consumer->name(), + "| inside InstructionFusion with may_duplicate=", may_duplicate_), + *consumer, producer); +} + +void InstructionFusion::DumpNotFusingState(HloComputation* computation, + HloInstruction* consumer, + HloInstruction* producer, + FusionDecision decision) { + // Readability optimizations: lack of fusion for tuple accesses + // generates a lot of noise. + if (producer->opcode() == HloOpcode::kGetTupleElement || + consumer->opcode() == HloOpcode::kGetTupleElement) { + return; + } + + RegisterFusionState( + *computation, + absl::StrCat("Not fusing |", producer->name(), "| into |", + consumer->name(), "| as ", decision.Explain()), + *consumer, producer); +} + +void InstructionFusion::DumpStateAfterFusion(HloComputation* computation, + HloInstruction* fusion_instruction, + const std::string& producer_name) { + RegisterFusionState( + *computation, + absl::StrCat( + "Fused |", producer_name, "| into |", fusion_instruction->name(), + "| inside InstructionFusion with may_duplicate=", may_duplicate_), + *fusion_instruction); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 743ec09829aa21..8217ab16a46f69 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -314,6 +314,20 @@ class InstructionFusion : public HloModulePass { // duplicated. std::function is_expensive_; + // Dumps the state of computation before fusion. + void DumpPreFusionState(HloComputation* computation, HloInstruction* consumer, + HloInstruction* producer, bool is_mof = false); + + // Dumps the state of computation and the reason why the fusion was not + // performed. + void DumpNotFusingState(HloComputation* computation, HloInstruction* consumer, + HloInstruction* producer, FusionDecision decision); + + // Dumps the state of computation after fusion happened. + void DumpStateAfterFusion(HloComputation* computation, + HloInstruction* fusion_instruction, + const std::string& producer_name); + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; From be93649864e1a0f04a6a6e217f75a19fe2c216b9 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Tue, 25 Jul 2023 08:01:00 -0700 Subject: [PATCH 107/410] Use lit consistently in XLA This change makes it so that the same lit runner is used in XLA regardless of repo, making recent changes to tensorflow's runlit.cfg.py unnecessary. Change all single quotes to double quotes in runlit.site.cfg.py to appease the formatter. 
PiperOrigin-RevId: 550887354 --- tensorflow/compiler/mlir/runlit.cfg.py | 19 ------- tensorflow/compiler/xla/glob_lit_test.bzl | 3 +- .../mlir/backends/cpu/transforms/tests/BUILD | 2 +- .../mlir/backends/gpu/transforms/tests/BUILD | 2 +- .../backends/openxla/transforms/tests/BUILD | 2 +- .../compiler/xla/mlir/framework/tests/BUILD | 2 +- .../xla/mlir/math/transforms/tests/BUILD | 2 +- .../xla/mlir/memref/transforms/tests/BUILD | 2 +- .../compiler/xla/mlir/runtime/ir/tests/BUILD | 2 +- .../xla/mlir/runtime/transforms/tests/BUILD | 2 +- .../tools/mlir_bisect/rewrites/tests/BUILD | 2 +- .../xla/mlir/tools/mlir_bisect/tests/BUILD | 2 +- .../compiler/xla/mlir/xla_cpu/tests/BUILD | 2 +- .../mlir_interpreter/dialects/tests/BUILD | 2 +- .../compiler/xla/python/ifrt/ir/tests/BUILD | 2 +- tensorflow/compiler/xla/runlit.site.cfg.py | 55 ++++++++++--------- .../compiler/xla/service/gpu/tests/BUILD | 2 +- .../xla/translate/hlo_to_mhlo/tests/BUILD | 2 +- .../xla/translate/mhlo_to_hlo/tests/BUILD | 2 +- .../mhlo_to_lhlo_with_xla/tests/BUILD | 2 +- 20 files changed, 47 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 41cb7760b7efdc..c5e64137ac8d00 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -60,25 +60,6 @@ llvm_config.config.substitutions.append( ('%tfrt_bindir', 'tensorflow/compiler/aot')) -subst_marker = 'SUBST_' -subst_marker_len = len(subst_marker) -# Include aditional substitutions that may be defined via params -llvm_config.config.substitutions.extend( - ('%%{%s}' % key, val) - for key, val in lit_config.params.items() - if not key.startswith(subst_marker) -) - -# Include ir substitutions for FileCheck -llvm_config.config.substitutions.append(( - '%{IR_SUBST}', - ' '.join( - "-D{}='{}'".format(key[subst_marker_len:], val.replace('[SPACE]', ' ')) - for key, val in lit_config.params.items() - if key.startswith(subst_marker) - ), -)) - # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) diff --git a/tensorflow/compiler/xla/glob_lit_test.bzl b/tensorflow/compiler/xla/glob_lit_test.bzl index a95496c8ded39b..d4d40834f00da3 100644 --- a/tensorflow/compiler/xla/glob_lit_test.bzl +++ b/tensorflow/compiler/xla/glob_lit_test.bzl @@ -46,12 +46,13 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): """ # Disable tests on windows for now, to enable testing rest of all xla and mlir. 
+ xla_root_dir = "tensorflow/compiler/xla/" native.py_test( name = name, srcs = ["@llvm-project//llvm:lit"], tags = tags + ["no_windows"], args = [ - "xla/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", + xla_root_dir + paths.basename(data[-1]) + " --config-prefix=runlit -v", ] + features, data = data + [ "//tensorflow/compiler/xla:litfiles", diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD index 2de8ab7761cd27..9a6889bc909302 100644 --- a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD index 7d563b53c30800..d81291d0cb0f1f 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD index cbfd8155fee7b8..0238b82c42ebc1 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/framework/tests/BUILD b/tensorflow/compiler/xla/mlir/framework/tests/BUILD index 7dd8c6bdf59c41..93953a26e914af 100644 --- a/tensorflow/compiler/xla/mlir/framework/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/framework/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD index eee4aaa5dbaf4d..bbee190ea813b4 100644 --- a/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD 
b/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD index eee4aaa5dbaf4d..bbee190ea813b4 100644 --- a/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD b/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD index 726567f1677b42..36bc769b2e90c5 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD index d5bb08e782731d..cc87e0d95b1f3f 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD index 66e100bb3f32d7..90f9c49503e1f8 100644 --- a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/tsl:tsl.default.bzl", "filegroup") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD index fdfa4f6294aad0..eee7853875d964 100644 --- a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/tsl:tsl.default.bzl", "filegroup") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD b/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD index 5145bc5af7a2a5..66e94b227fd05a 100644 --- a/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) 
diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/BUILD b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/BUILD index c082dedad88ef4..2e0b6b019e6cba 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/tsl:tsl.default.bzl", "filegroup") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/tensorflow/compiler/xla/python/ifrt/ir/tests/BUILD b/tensorflow/compiler/xla/python/ifrt/ir/tests/BUILD index b625ae0a48262c..ec8b47bc6bdd9e 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/tests/BUILD +++ b/tensorflow/compiler/xla/python/ifrt/ir/tests/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( diff --git a/tensorflow/compiler/xla/runlit.site.cfg.py b/tensorflow/compiler/xla/runlit.site.cfg.py index 136900ae840e47..e6d5ffa1ca6e20 100644 --- a/tensorflow/compiler/xla/runlit.site.cfg.py +++ b/tensorflow/compiler/xla/runlit.site.cfg.py @@ -18,56 +18,57 @@ import lit.llvm # Handle the test srcdir for platforms. On windows, things are weird with bazel. -if platform.system() == 'Windows': - srcdir = os.environ['TEST_SRCDIR'] - real_test_srcdir = srcdir[:srcdir.find('xla/')] - external_srcdir = os.path.join(real_test_srcdir, 'external') +if platform.system() == "Windows": + srcdir = os.environ["TEST_SRCDIR"] + real_test_srcdir = srcdir[:srcdir.find("xla/")] + external_srcdir = os.path.join(real_test_srcdir, "external") else: - real_test_srcdir = os.environ['TEST_SRCDIR'] + real_test_srcdir = os.environ["TEST_SRCDIR"] external_srcdir = real_test_srcdir # Lint for undefined variables is disabled as config is not defined inside this # file, instead config is injected by lit.py. The structure is common for lit # tests and intended to only persist temporarily (b/136126535). # pylint: disable=undefined-variable -config.llvm_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'llvm') +config.llvm_tools_dir = os.path.join(external_srcdir, "llvm-project", "llvm") config.mlir_obj_root = os.path.join(real_test_srcdir) -config.mlir_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'mlir') +config.mlir_tools_dir = os.path.join(external_srcdir, "llvm-project", "mlir") # TODO(jpienaar): Replace with suffices in build rule. 
-config.suffixes = ['.td', '.mlir', '.pbtxt'] +config.suffixes = [".td", ".mlir", ".pbtxt"] -xla_root_dir = 'tensorflow/compiler/xla/' +xla_root_dir = "tensorflow/compiler/xla/" mlir_tf_tools_dirs = [ - 'mlir/backends/cpu', - 'mlir/backends/gpu', - 'mlir/runtime', - 'mlir/tools/mlir_bisect', - 'mlir_hlo', - 'python/ifrt/ir/tests', - 'service/gpu/tests', - 'service/mlir_gpu', - 'translate', - 'translate/mhlo_to_lhlo_with_xla', + "mlir/backends/cpu", + "mlir/backends/gpu", + "mlir/runtime", + "mlir/tools/mlir_bisect", + "mlir_hlo", + "python/ifrt/ir/tests", + "service/gpu/tests", + "service/mlir_gpu", + "translate", + "translate/mhlo_to_lhlo_with_xla", ] config.mlir_tf_tools_dirs = [ - os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], xla_root_dir, + os.path.join(real_test_srcdir, os.environ["TEST_WORKSPACE"], xla_root_dir, s) for s in mlir_tf_tools_dirs ] -test_dir = os.environ['TEST_TARGET'] -test_dir = test_dir.strip('/').rsplit(':', 1)[0] +test_dir = os.environ["TEST_TARGET"] +test_dir = test_dir.strip("/").rsplit(":", 1)[0] config.mlir_test_dir = os.path.join(real_test_srcdir, - os.environ['TEST_WORKSPACE'], test_dir) + os.environ["TEST_WORKSPACE"], test_dir) -if platform.system() == 'Windows': +if platform.system() == "Windows": # Configure this to work with msys2, TF's preferred windows bash. - config.lit_tools_dir = '/usr/bin' + config.lit_tools_dir = "/usr/bin" lit.llvm.initialize(lit_config, config) + # Let the main config do the real work. lit_config.load_config( config, os.path.join( - os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], - 'xla/runlit.cfg.py'))) + os.path.join(real_test_srcdir, os.environ["TEST_WORKSPACE"], + xla_root_dir + "runlit.cfg.py"))) # pylint: enable=undefined-variable diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index d72f115cd8e012..c396b2dfb4db7d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -12,7 +12,7 @@ load( "//tensorflow/tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load( "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", diff --git a/tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/BUILD b/tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/BUILD index a4d673eeefafd1..b541b009e256a7 100644 --- a/tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/BUILD b/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/BUILD index dd214111598528..828ceabe1d2355 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff 
--git a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD index a53073f341d955..ac23a29bced276 100644 --- a/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD +++ b/tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/tsl:tsl.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/xla:glob_lit_test.bzl", "glob_lit_tests") load( "//tensorflow/tsl/platform:build_config_root.bzl", "tf_cuda_tests_tags", From 861ad37d23dc27c8e69477ac083f7f59ca44ffe2 Mon Sep 17 00:00:00 2001 From: justkw Date: Tue, 25 Jul 2023 08:38:09 -0700 Subject: [PATCH 108/410] curl upgrade to 8.1.2 --- tensorflow/workspace2.bzl | 6 +++--- third_party/curl.BUILD | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 9f1f474c270fe9..8e5afacddf8844 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -512,10 +512,10 @@ def _tf_repositories(): tf_http_archive( name = "curl", build_file = "//third_party:curl.BUILD", - sha256 = "5fd29000a4089934f121eff456101f0a5d09e2a3e89da1d714adf06c4be887cb", - strip_prefix = "curl-8.0.1", + sha256 = "2e5a9b8fcdc095bdd2f079561f369de71c5eb3b80f00a702fbe9a8b8d9897891", + strip_prefix = "curl-8.1.2", system_build_file = "//third_party/systemlibs:curl.BUILD", - urls = tf_mirror_urls("https://curl.haxx.se/download/curl-8.0.1.tar.gz"), + urls = tf_mirror_urls("https://curl.haxx.se/download/curl-8.1.2.tar.gz"), ) # WARNING: make sure ncteisen@ and vpai@ are cc-ed on any CL to change the below rule diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index 4de0704b099e42..fba1bc6abda748 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -40,10 +40,18 @@ cc_library( "lib/asyn.h", "lib/asyn-ares.c", "lib/base64.c", + "lib/bufq.c", + "lib/bufq.h", "lib/bufref.c", "lib/bufref.h", "lib/c-hyper.c", "lib/c-hyper.h", + "lib/cf-h1-proxy.c", + "lib/cf-h1-proxy.h", + "lib/cf-h2-proxy.c", + "lib/cf-h2-proxy.h", + "lib/cf-haproxy.c", + "lib/cf-haproxy.h", "lib/cf-https-connect.c", "lib/cf-https-connect.h", "lib/cf-socket.c", @@ -121,6 +129,8 @@ cc_library( "lib/doh.h", "lib/dynbuf.c", "lib/dynbuf.h", + "lib/dynhds.c", + "lib/dynhds.h", "lib/easy.c", "lib/easy_lock.h", "lib/easygetopt.c", @@ -147,8 +157,6 @@ cc_library( "lib/getinfo.h", "lib/gopher.c", "lib/gopher.h", - "lib/h2h3.c", - "lib/h2h3.h", "lib/hash.c", "lib/hash.h", "lib/headers.c", @@ -164,6 +172,8 @@ cc_library( "lib/hsts.h", "lib/http.c", "lib/http.h", + "lib/http1.c", + "lib/http1.h", "lib/http2.c", "lib/http2.h", "lib/http_aws_sigv4.c", From bf0761a4c54d5b9f0901566f179f061de7531365 Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Tue, 25 Jul 2023 17:16:32 +0100 Subject: [PATCH 109/410] Limit the version of wrapt to be used to prevent unit test failures A change in the behaviour of wrapt ends up in a unit test failure for TensorFlow. This can be removed once mitigation in TensorFlow is in place. https://github.com/tensorflow/tensorflow/issues/60687 and https://github.com/GrahamDumpleton/wrapt/issues/231 This is an alternative and less intrusive fix to https://github.com/tensorflow/tensorflow/pull/60688 which was inadvertently reverted by a mis-merge in another commit. 
--- tensorflow/tools/pip_package/setup.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 5b84e9368f544c..6d3e9383ec0205 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -93,12 +93,15 @@ def standard_or_nightly(standard, nightly): 'numpy >= 1.23.5', 'opt_einsum >= 2.3.2', 'packaging', - 'protobuf>=3.20.3,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5', + # pylint:disable=line-too-long + ( + 'protobuf>=3.20.3,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5' + ), 'setuptools', 'six >= 1.12.0', 'termcolor >= 1.1.0', 'typing_extensions >= 3.6.6', - 'wrapt >= 1.11.0', + 'wrapt >= 1.11.0, < 1.15', # This looks worse as a wrapped line. pylint:disable=line-too-long ( 'tensorflow-io-gcs-filesystem >= 0.23.1;platform_machine!="arm64" or' From 0af8b588dc0ef57a022d4bf0f3ae74ab074a18ae Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 25 Jul 2023 09:17:37 -0700 Subject: [PATCH 110/410] [XLA:GPU] Remove executable cache from autotuner util. This cache in practice is hit just once per entry - when Triton GEMM autotuner uses a thread pool to precompile all configurations. Therefore it's more efficient to store these compiled executables in the autotuner locally instead. This reduces their lifetime and requires less RAM. Also remove a duplicate config. PiperOrigin-RevId: 550906028 --- tensorflow/compiler/xla/service/gpu/BUILD | 6 - .../xla/service/gpu/autotuner_compile_util.cc | 77 +--------- .../xla/service/gpu/autotuner_compile_util.h | 22 +-- .../xla/service/gpu/triton_autotuner.cc | 139 ++++++++---------- 4 files changed, 74 insertions(+), 170 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4bf5b9d4159a14..b164a34e6f8fb8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1335,16 +1335,10 @@ cc_library( ":autotuner_util", ":gpu_executable_run_options", ":ir_emission_utils", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "//tensorflow/compiler/xla:autotune_results_proto_cc", - "//tensorflow/compiler/xla:autotuning_proto_cc", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc index 6bd4e7a6c64782..2d205e81558dce 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.cc @@ -17,20 +17,13 @@ limitations under the License. 
#include #include -#include #include #include -#include "absl/base/const_init.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "tensorflow/compiler/xla/autotune_results.pb.h" -#include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_clone_context.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" @@ -54,34 +47,6 @@ namespace gpu { namespace { -struct CompilationKey { - template - friend H AbslHashValue(H h, const CompilationKey& k) { - return H::combine(std::move(h), k.autotune_key, k.res.SerializeAsString()); - } - - bool operator==(const CompilationKey& k) const { - return res.SerializeAsString() == k.res.SerializeAsString() && - autotune_key == k.autotune_key; - } - - std::string ToString() const { - return absl::StrFormat("", autotune_key.ToString(), - res.DebugString()); - } - - AutotuneCacheKey autotune_key; - AutotuneResult res; -}; - -static absl::Mutex executable_cache_mutex(absl::kConstInit); -// The key is the "standard" AutotuneCacheKey, which encompasses both the device -// type and the code of the HLO. We need this because TritonAutotuner may be -// called with different device types, and an executable compiled for one device -// type may not run on another. -static auto& ABSL_GUARDED_BY(executable_cache_mutex) executable_cache = - *new absl::node_hash_map>(); - std::vector ExecutionInputsFromBuffers( Executable* executable, absl::Span buffers) { const HloInstruction::InstructionVector& params = @@ -126,16 +91,9 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, } StatusOr> -AutotunerCompileUtil::GenerateAndProfileExecutable( - const AutotuneResult& config, const AutotuneCacheKey& cache_key, - se::Stream* stream, absl::Span input_buffers, - GenerateModuleFn extractor) { - TF_ASSIGN_OR_RETURN(Executable * executable, - Compile(config, cache_key, std::move(extractor))); - - if (!executable) { - return {std::nullopt}; - } +AutotunerCompileUtil::ProfileExecutable( + Executable* executable, se::Stream* stream, + absl::Span input_buffers) { { std::vector execution_inputs = ExecutionInputsFromBuffers(executable, input_buffers); @@ -157,29 +115,9 @@ AutotunerCompileUtil::GenerateAndProfileExecutable( timer_duration, execution_output.Commit().ConsumeResult()); } -StatusOr AutotunerCompileUtil::Compile( - const AutotuneResult& res, const AutotuneCacheKey& cache_key, +StatusOr> AutotunerCompileUtil::Compile( GenerateModuleFn extractor) { - CompilationKey key{cache_key, res}; - { - absl::MutexLock lock(&executable_cache_mutex); - auto it = executable_cache.find(key); - if (it != executable_cache.end()) { - VLOG(4) << "Compilation cache hit"; - return it->second.get(); - } - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CompileNoCache(std::move(extractor))); - absl::MutexLock lock(&executable_cache_mutex); - auto [it, inserted] = executable_cache.emplace(key, std::move(executable)); - return it->second.get(); -} - -StatusOr> AutotunerCompileUtil::CompileNoCache( - GenerateModuleFn module_extractor) { - StatusOr> new_hlo_module = module_extractor(); + StatusOr> new_hlo_module = extractor(); if (new_hlo_module.status().GetPayload(kUncompilableFusion).has_value()) { // Incompatible value of split-k is an expected failure. 
return std::unique_ptr(); @@ -234,10 +172,5 @@ StatusOr AutotunerCompileUtil::Execute( return std::move(output); } -/*static*/ void AutotunerCompileUtil::ClearCompilationCache() { - absl::MutexLock lock(&executable_cache_mutex); - executable_cache.clear(); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h index 1ad8ddb5ba9acc..fdece262b54785 100644 --- a/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h +++ b/tensorflow/compiler/xla/service/gpu/autotuner_compile_util.h @@ -20,8 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/autotune_results.pb.h" -#include "tensorflow/compiler/xla/autotuning.pb.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_clone_context.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" @@ -36,8 +34,6 @@ namespace gpu { // Autotuning utils which require compiling fusions separately. Requires a // separate target, as runtime autotuning cannot perform compilation. -// -// Uses a global cache, *not* unique per instance. class AutotunerCompileUtil { public: using GenerateModuleFn = @@ -64,35 +60,25 @@ class AutotunerCompileUtil { // Runs the resulting executable with the given extractor, cached with // `(cache_key, config)`. Returns `std::nullopt` on expected failure, bad // `Status` otherwise. - StatusOr> GenerateAndProfileExecutable( - const AutotuneResult& config, const AutotuneCacheKey& cache_key, - se::Stream* stream, absl::Span input_buffers, - GenerateModuleFn extractor); + StatusOr> ProfileExecutable( + Executable* executable, se::Stream* stream, + absl::Span input_buffers); // Generic method to compile a generated module from `extractor` in isolation. // - // On *expected* failures we will store an empty unique_ptr in cache. - // // Returns: // - `nullptr` on *expected* failure // - `Executable` if everything goes fine. // - `Status` on *unexpected* failure. - StatusOr Compile( - const AutotuneResult& res, const AutotuneCacheKey& cache_key, + StatusOr> Compile( AutotunerCompileUtil::GenerateModuleFn extractor); - // Clears the global compilation cache. 
- static void ClearCompilationCache(); - private: AutotunerCompileUtil(const AutotuneConfig& config, Compiler* compiler, se::StreamExecutor& stream_executor, se::Stream& stream, se::DeviceMemoryAllocator& allocator, const DebugOptions& opts); - StatusOr> CompileNoCache( - AutotunerCompileUtil::GenerateModuleFn module_extractor); - StatusOr Execute(Executable& executable, std::vector arguments); diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index 75b527a4ca5140..eabd445a72240c 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -89,6 +89,19 @@ static AutotuneResult::TritonGemmKey GemmKey(int64_t block_m, int64_t block_n, return key; } +struct CompilationKey { + template + friend H AbslHashValue(H h, const CompilationKey& k) { + return H::combine(std::move(h), k.key.SerializeAsString()); + } + + bool operator==(const CompilationKey& k) const { + return key.SerializeAsString() == k.key.SerializeAsString(); + } + + AutotuneResult::TritonGemmKey key; +}; + class TritonAutotunerVisitor : public DfsHloRewriteVisitor { public: TritonAutotunerVisitor( @@ -108,9 +121,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { VLOG(1) << "Tuning " << hlo->ToString(); TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, AutotunerUtil::Autotune(hlo, config_, [&] { - return AutotuneMatmulNoCache( - hlo, - AutotuneCacheKey(config_.GetModelStr(), *hlo)); + return AutotuneMatmulNoCache(hlo); })); VLOG(1) << "Result: " << autotune_result.ShortDebugString(); @@ -129,11 +140,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { private: // Autotunes a matmul without using the autotuning cache. - // - // `cache_key`: The cache key corresponding to the code of the fusion and the - // device type. Passing it to avoid recalculating it everywhere it's needed. - StatusOr AutotuneMatmulNoCache( - const HloInstruction* instr, const AutotuneCacheKey& cache_key) { + StatusOr AutotuneMatmulNoCache(const HloInstruction* instr) { if (config_.IsDeviceless()) { return InternalError( "Expect autotune result cache hit for deviceless compilation."); @@ -169,26 +176,38 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(config_.GetExecutor()); - // Pre-compile all versions first using the thread pool. + absl::Mutex executables_mutex; + absl::flat_hash_map> + executables; + + auto compile = [&](const AutotuneResult::TritonGemmKey& conf) { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + autotuner_compile_util_->Compile([&] { + return TritonGemmAutotuneExtractor( + conf, gpu_device_info, + fusion.FusionInstruction()); + })); + absl::MutexLock lock(&executables_mutex); + CHECK(executables.emplace(CompilationKey{conf}, std::move(executable)) + .second); + return OkStatus(); + }; + + // Pre-compile all versions first. 
if (thread_pool_ && debug_opts.xla_gpu_force_compilation_parallelism() != 1) { tsl::BlockingCounter counter(configurations.size()); for (const AutotuneResult::TritonGemmKey& conf : configurations) { thread_pool_->Schedule([&] { - AutotuneResult config; - *config.mutable_triton() = conf; - StatusOr res = - autotuner_compile_util_->Compile(config, cache_key, [&] { - return TritonGemmAutotuneExtractor(conf, gpu_device_info, - fusion.FusionInstruction()); - }); - if (!res.ok()) { - LOG(ERROR) << "Failure: " << res.status(); - } + TF_CHECK_OK(compile(conf)); counter.DecrementCount(); }); } counter.Wait(); + } else { + for (const AutotuneResult::TritonGemmKey& conf : configurations) { + TF_RETURN_IF_ERROR(compile(conf)); + } } std::vector inputs; @@ -204,7 +223,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { if (config_.should_check_correctness()) { TF_ASSIGN_OR_RETURN( reference_buffer, - RunMatmulWithCublas(fusion, stream, allocator, inputs, cache_key)); + RunMatmulWithCublas(fusion, stream, allocator, inputs)); } std::vector results; @@ -214,9 +233,14 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotuneResult res; *res.mutable_triton() = conf; - TF_ASSIGN_OR_RETURN( - std::optional profiling_output, - RunMatmulWithConfig(fusion, conf, stream, inputs, cache_key)); + auto it = executables.find(CompilationKey{conf}); + if (it == executables.end() || it->second == nullptr) { + VLOG(1) << "Skipping this tiling."; + continue; + } + TF_ASSIGN_OR_RETURN(std::optional profiling_output, + autotuner_compile_util_->ProfileExecutable( + it->second.get(), stream, inputs)); if (!profiling_output) { VLOG(1) << "Skipping this tiling."; @@ -283,28 +307,6 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { return best; } - // Run a fusion with a given tiling on given buffers. - // Returns `true` if run successfully, `false` if the tiling has to be - // skipped. - // - // `cache_key`: The cache key corresponding to the code of the fusion and the - // device type. Passing it to avoid recalculating it everywhere it's needed. - StatusOr> RunMatmulWithConfig( - const HloComputation& hlo_computation, - const AutotuneResult::TritonGemmKey& autotune_config, se::Stream* stream, - absl::Span input_buffers, - const AutotuneCacheKey& cache_key) { - AutotuneResult config; - *config.mutable_triton() = autotune_config; - - return autotuner_compile_util_->GenerateAndProfileExecutable( - config, cache_key, stream, input_buffers, [&] { - return TritonGemmAutotuneExtractor( - autotune_config, GetGpuDeviceInfo(config_.GetExecutor()), - hlo_computation.FusionInstruction()); - }); - } - StatusOr> TritonGemmAutotuneExtractor( const AutotuneResult::TritonGemmKey& key, const GpuDeviceInfo& gpu_device_info, const HloInstruction* fusion) { @@ -345,31 +347,20 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { return new_module; } - // Runs a matmul fusion without Triton - with cuBLAS, to generate a reference - // output. - // - // `cache_key`: The cache key corresponding to the code of the fusion and the - // device type. Passing it to avoid recalculating it everywhere it's needed. + // Runs matmul fusion contents without Triton - with cuBLAS, to generate + // a reference output. StatusOr RunMatmulWithCublas( const HloComputation& original_computation, se::Stream* stream, se::DeviceMemoryAllocator* allocator, - absl::Span input_buffers, - const AutotuneCacheKey& cache_key) { - AutotuneResult res; - - // We need some value to cache compilation. 
We associate the compiled module - // with autotune key + result picking 0th algorithm for cuBLAS. - AutotuneResult::GemmKey gemm; - gemm.set_algorithm(0); - *res.mutable_gemm() = gemm; - + absl::Span input_buffers) { + StatusOr> executable = + autotuner_compile_util_->Compile([&] { + return CublasGemmAutotuneExtractor( + GetGpuDeviceInfo(config_.GetExecutor()), &original_computation); + }); TF_ASSIGN_OR_RETURN(std::optional output, - autotuner_compile_util_->GenerateAndProfileExecutable( - res, cache_key, stream, input_buffers, [&] { - return CublasGemmAutotuneExtractor( - GetGpuDeviceInfo(config_.GetExecutor()), - &original_computation); - })); + autotuner_compile_util_->ProfileExecutable( + executable->get(), stream, input_buffers)); TF_RET_CHECK(output.has_value()); return std::move(output->output); } @@ -444,14 +435,14 @@ std::vector GetFixedMatmulAutotuneConfigs( GemmKey(128, 256, 32, 1, 3, 8), GemmKey(256, 128, 32, 1, 3, 8), GemmKey(256, 64, 32, 1, 4, 4), GemmKey(64, 256, 32, 1, 4, 4), GemmKey(128, 64, 32, 1, 4, 4), GemmKey(64, 128, 32, 1, 4, 4), - GemmKey(128, 256, 32, 1, 3, 8), GemmKey(256, 128, 128, 1, 3, 8), - GemmKey(256, 64, 128, 1, 4, 4), GemmKey(64, 256, 128, 1, 4, 4), - GemmKey(128, 128, 128, 1, 4, 4), GemmKey(128, 64, 64, 1, 4, 4), - GemmKey(64, 128, 64, 1, 4, 4), GemmKey(128, 32, 64, 1, 4, 4), - GemmKey(64, 32, 64, 1, 4, 4), GemmKey(32, 128, 32, 1, 4, 4), - GemmKey(128, 128, 32, 1, 4, 4), GemmKey(16, 16, 256, 1, 3, 4), - GemmKey(128, 128, 64, 2, 1, 8), GemmKey(64, 64, 64, 1, 2, 4), - GemmKey(16, 64, 256, 8, 1, 4), GemmKey(256, 256, 128, 1, 3, 8)}, + GemmKey(256, 128, 128, 1, 3, 8), GemmKey(256, 64, 128, 1, 4, 4), + GemmKey(64, 256, 128, 1, 4, 4), GemmKey(128, 128, 128, 1, 4, 4), + GemmKey(128, 64, 64, 1, 4, 4), GemmKey(64, 128, 64, 1, 4, 4), + GemmKey(128, 32, 64, 1, 4, 4), GemmKey(64, 32, 64, 1, 4, 4), + GemmKey(32, 128, 32, 1, 4, 4), GemmKey(128, 128, 32, 1, 4, 4), + GemmKey(16, 16, 256, 1, 3, 4), GemmKey(128, 128, 64, 2, 1, 8), + GemmKey(64, 64, 64, 1, 2, 4), GemmKey(16, 64, 256, 8, 1, 4), + GemmKey(256, 256, 128, 1, 3, 8)}, std::back_inserter(configs)); } if (compute_capability.IsAtLeast(se::CudaComputeCapability::HOPPER)) { From ee37da9984452c26c3eac9463f2c0edf670b4c95 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Tue, 25 Jul 2023 09:27:41 -0700 Subject: [PATCH 111/410] [XLA] Simplify should_fuse computations. Makes the code easier to read and reduces a level of nestings. PiperOrigin-RevId: 550908705 --- .../xla/service/instruction_fusion.cc | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index b9e70760ee80fa..0d3967e6fd64d8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -575,58 +575,58 @@ StatusOr InstructionFusion::Run( HloInstruction* fusion_instruction = nullptr; - FusionDecision should_fuse(do_not_duplicate.count(operand) == 0, - "operand can not be duplicated"); - // Try "regular" fusion if the operand may be duplicated. Otherwise, // perform multi-output fusion, unless this creates a cycle. 
- if (should_fuse) { - should_fuse = ShouldFuse(instruction, i); - if (should_fuse && consume_fuel()) { - if (dump_fusion) { - DumpPreFusionState(computation, /*consumer=*/instruction, - /*producer=*/operand); - } + FusionDecision use_regular_fusion(do_not_duplicate.count(operand) == 0, + "operand can not be duplicated"); + if (use_regular_fusion) { + use_regular_fusion = + use_regular_fusion.And(ShouldFuse(instruction, i)); + } - fusion_queue->PreFusion(operand, instruction); - fusion_instruction = Fuse(operand, instruction, computation); + if (use_regular_fusion && consume_fuel()) { + if (dump_fusion) { + DumpPreFusionState(computation, /*consumer=*/instruction, + /*producer=*/operand); } + + fusion_queue->PreFusion(operand, instruction); + fusion_instruction = Fuse(operand, instruction, computation); } - if (!should_fuse) { - FusionDecision can_fuse_mof = - ShouldFuseIntoMultiOutput(instruction, i); - if (can_fuse_mof) { - can_fuse_mof = can_fuse_mof.And( + FusionDecision use_mof; + if (!use_regular_fusion) { + use_mof = ShouldFuseIntoMultiOutput(instruction, i); + if (use_mof) { + use_mof = use_mof.And( FusionDecision{!MultiOutputFusionCreatesCycle( operand, instruction, *reachability), "multi-output fusion creates a cycle"}); } - if (can_fuse_mof) { - if (consume_fuel()) { - if (dump_fusion) { - DumpPreFusionState(computation, /*consumer=*/instruction, - /*producer=*/operand, /*is_mof=*/true); - } - - fusion_queue->PreFusion(operand, instruction); - fusion_instruction = - FuseIntoMultiOutput(operand, instruction, computation); + if (use_mof && consume_fuel()) { + if (dump_fusion) { + DumpPreFusionState(computation, /*consumer=*/instruction, + /*producer=*/operand, /*is_mof=*/true); } + + fusion_queue->PreFusion(operand, instruction); + fusion_instruction = + FuseIntoMultiOutput(operand, instruction, computation); } - should_fuse = should_fuse.Or(can_fuse_mof); } if (fusion_instruction == nullptr) { - CHECK(!should_fuse.CanFuse()); + FusionDecision fusion_decision = use_regular_fusion.Or(use_mof); + CHECK(!fusion_decision.CanFuse()); if (dump_fusion) { VLOG(2) << "Not fusing " << operand->ToShortString() << "| into |" << instruction->ToShortString() << "| as " - << should_fuse.Explain(); + << fusion_decision.Explain(); DumpNotFusingState(computation, /*consumer=*/instruction, - /*producer=*/operand, /*decision=*/should_fuse); + /*producer=*/operand, + /*decision=*/fusion_decision); } fusion_queue->NotFusingInstruction(operand, instruction); From d0cb12441747ef9fb14137cb99f0b6a17e22b5e4 Mon Sep 17 00:00:00 2001 From: David Svantesson Date: Tue, 25 Jul 2023 09:33:40 -0700 Subject: [PATCH 112/410] PR #61235: Add inter scheduler support on AArch64 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/61235 This PR adds support for inter op scheduler in the oneDNN + ACL build. It enables the creation of more than 1 scheduler inside ACL to increase performance of models with parallel ops. For benchmarked NLP models the average performance increase is 9%, for CV classification models its around 2%. 
The below benchmarks were done with the following PR’s applied as patches: #60026, #60723, #61110, #61114, #61093, #61123 ![nlp_models_benchmarked](https://github.com/tensorflow/tensorflow/assets/117736650/7a3a4df1-475b-4dc4-ab85-7e9b97eb7b27) ![cv_models_benchmarked](https://github.com/tensorflow/tensorflow/assets/117736650/245103e6-f6d1-4da9-abb8-3a90a86217a9) Copybara import of the project: -- 0883ab272279db14e4fef1e7829fd5f363178f63 by David Svantesson : Add inter scheduler support -- 8dae4150505064fc8d6dee1a48df66de67c96af3 by Milos Puzovic : Address comments from review -- ecac66d0ba9407ca5a48522a3b7c64a9da203223 by Milos Puzovic : Fix formating Merging this change closes #61235 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/61235 from davsva01:acl_inter_scheduler ecac66d0ba9407ca5a48522a3b7c64a9da203223 PiperOrigin-RevId: 550910306 --- tensorflow/core/kernels/mkl/mkl_concat_op.cc | 6 +- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 6 +- .../kernels/mkl/mkl_conv_grad_input_ops.cc | 6 +- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 6 +- .../mkl/mkl_eltwise_activation_base_op.h | 6 +- .../kernels/mkl/mkl_fused_batch_norm_op.cc | 10 +- .../core/kernels/mkl/mkl_matmul_ops_common.h | 10 +- .../kernels/mkl/mkl_pooling_ops_common.cc | 4 +- .../core/kernels/mkl/mkl_pooling_ops_common.h | 6 +- .../core/kernels/mkl/mkl_quantize_op.cc | 6 +- tensorflow/core/kernels/mkl/mkl_relu_op.cc | 10 +- tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 6 +- tensorflow/core/util/mkl_util.h | 20 ++-- tensorflow/tsl/util/onednn_threadpool.h | 2 + tensorflow/workspace2.bzl | 6 +- .../acl_thread_local_scheduler.patch | 98 +++++++++++++++++++ .../onednn_acl_thread_local_scheduler.patch | 98 +++++++++++++++++++ 17 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 third_party/compute_library/acl_thread_local_scheduler.patch create mode 100644 third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch diff --git a/tensorflow/core/kernels/mkl/mkl_concat_op.cc b/tensorflow/core/kernels/mkl/mkl_concat_op.cc index 497f997860176b..804567b7e79b25 100644 --- a/tensorflow/core/kernels/mkl/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_concat_op.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -297,7 +297,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { const dnnl::memory& dst_data, const MklConcatFwdParams& concat_fwd_dims, std::shared_ptr fwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif DCHECK_EQ(in_data.size(), context_.data_mem.size()); @@ -397,7 +397,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { struct ConcatFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc index 7c24a23d68226e..d6663baa914741 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/mkl/mkl_conv_ops.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -93,7 +93,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { void Execute(const T* src_data, const T* diff_filter_data, const T* diff_bias_data, const T* diff_dst_data, std::shared_ptr bwd_filter_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -315,7 +315,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { struct ConvBwdFilterContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc index 16a6db176843b1..19a484a387a409 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_conv_ops.h" #include "tensorflow/core/util/use_cudnn.h" #include "tensorflow/core/util/work_sharder.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -100,7 +100,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { void Execute(const T* diff_src_data, const T* filter_data, const T* diff_dst_data, std::shared_ptr bwd_input_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -255,7 +255,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { } struct ConvBwdInputContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 6d76cc4aa5cfbc..84486cc1abb9d3 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/core/kernels/mkl/mkl_kernel_util.h" #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h" #include "tensorflow/core/kernels/no_op.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -177,7 +177,7 @@ class MklConvFwdPrimitive : public MklPrimitive { const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data, const MklConvFwdParams& convFwdDims, std::shared_ptr fwd_stream, void* sp_data) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) // When we are using single global cache then in this case we can have // multiple threads running the same primitive that we created so this // should happen under the lock. 
@@ -582,7 +582,7 @@ class MklConvFwdPrimitive : public MklPrimitive { struct ConvFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) // Guards Execution() mutex primitive_execution_mu_; #endif diff --git a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h index 9bfddc2f1d0991..40d67a2382bfdd 100644 --- a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h +++ b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -100,7 +100,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { // src_data: input data buffer of src // dst_data: output data buffer of dst void Execute(const T* src_data, T* dst_data, OpKernelContext* op_context) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif context_.src_mem->set_data_handle( @@ -202,7 +202,7 @@ class MklEltwiseFwdActivationPrimitive : public MklPrimitive { struct EltwiseFwdActivationContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 9d4736fc6a83a8..e8f0d26915ecd3 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -131,7 +131,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { #endif // !ENABLE_ONEDNN_V3 U* mean_data, U* variance_data, std::shared_ptr fwd_stream, U* workspace_data) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -416,7 +416,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { struct BatchNormFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; @@ -542,7 +542,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { U* diff_scale_data, U* diff_shift_data, U* res_space_data, #endif // !ENABLE_ONEDNN_V3 std::shared_ptr bwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -765,7 +765,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { struct BatchNormBwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index 58e1e9795bfdb4..e8028e4b865653 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/mkl/mkl_kernel_util.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/onednn_env_vars.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -153,7 +153,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { const void* bias_data, Toutput* dst_data, const MklDnnMatMulFwdParams& matmul_fwd_params, void* sp_data, std::shared_ptr fwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif context_.src_mem->set_data_handle( @@ -445,7 +445,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { struct MklDnnMatMulFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) // Guards Execution() mutex primitive_execution_mu_; #endif @@ -811,7 +811,7 @@ class MklMatMulPrimitive : public MklPrimitive { void Execute(const std::shared_ptr& stream, const Tlhs* a_data, const Trhs* b_data, const Toutput* c_data, void* sp_data, void* mul_data = nullptr, void* add_data = nullptr) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -1011,7 +1011,7 @@ class MklMatMulPrimitive : public MklPrimitive { } struct MklMatMulContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index c73233cab8dd26..1f8ebc117e782b 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -101,7 +101,7 @@ template void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, void* ws_data, std::shared_ptr fwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -218,7 +218,7 @@ template void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data, std::shared_ptr bwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h index 012244a79691ad..fe5e7f032855a6 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -27,7 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -176,7 +176,7 @@ class MklPoolingFwdPrimitive : public MklPrimitive { struct PoolingFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; @@ -331,7 +331,7 @@ class MklPoolingBwdPrimitive : public MklPrimitive { }; struct PoolingBwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc index 36e7178d5b7249..41c9d260d31c16 100644 --- a/tensorflow/core/kernels/mkl/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_quantize_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -111,7 +111,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { void* scale_data, #endif // ENABLE_ONEDNN_V3 std::shared_ptr reorder_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -202,7 +202,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { #endif // ENABLE_ONEDNN_V3 } -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_relu_op.cc b/tensorflow/core/kernels/mkl/mkl_relu_op.cc index 03f19e21da86ca..9a43376b35bd75 100644 --- a/tensorflow/core/kernels/mkl/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_relu_op.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif @@ -77,7 +77,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst void Execute(const T* src_data, T* dst_data, std::shared_ptr fwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #ifndef ENABLE_ONEDNN_OPENMP @@ -160,7 +160,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { struct EltwiseFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; @@ -259,7 +259,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { // diff_src_data: output data buffer of diff_src void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data, std::shared_ptr bwd_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #ifndef ENABLE_ONEDNN_OPENMP @@ -368,7 +368,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { struct EltwiseBwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc index 50caffa5e2d1e7..27ed530aeae4a7 100644 --- a/tensorflow/core/kernels/mkl/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_softmax_op.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -64,7 +64,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { // dst_data: output data buffer of dst void Execute(const T* src_data, T* dst_data, std::shared_ptr fwd_cpu_stream) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(primitive_execution_mu_); #endif #if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) @@ -160,7 +160,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { struct SoftmaxFwdContext context_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_execution_mu_; #endif }; diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index caa0b11305d199..0a8015fa6cce1f 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -39,7 +39,7 @@ limitations under the License. 
#include "tensorflow/core/util/onednn_env_vars.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) #include "tensorflow/core/platform/mutex.h" #endif #include "tensorflow/tsl/util/onednn_threadpool.h" @@ -1859,7 +1859,7 @@ class LRUCache { } T* GetOp(const string& key) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(lru_mu_); #endif auto it = cache_.find(key); @@ -1875,7 +1875,7 @@ class LRUCache { } void SetOp(const string& key, T* op) { -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex_lock lock(lru_mu_); #endif if (lru_list_.size() >= capacity_) { @@ -1886,7 +1886,7 @@ class LRUCache { lru_list_.push_front(key); Entry entry(op, lru_list_.begin()); cache_.emplace(std::make_pair(key, std::move(entry))); -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) FinishedAllocation(key); #endif } @@ -1899,7 +1899,7 @@ class LRUCache { lru_list_.clear(); } -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) bool IsAllocating(const string& key) { mutex_lock lock(in_flight_mu_); return in_flight_.find(key) != in_flight_.end(); @@ -1964,7 +1964,7 @@ class LRUCache { // entry, while the back of the list is the least recently accessed entry. std::list lru_list_; -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) // Guards access to the cache and LRU list mutex lru_mu_; @@ -1983,7 +1983,7 @@ class MklPrimitiveFactory { ~MklPrimitiveFactory() {} MklPrimitive* GetOp(const string& key) { -#ifndef DNNL_AARCH64_USE_ACL +#if !defined(DNNL_AARCH64_USE_ACL) || !defined(ENABLE_ONEDNN_OPENMP) auto& lru_cache = MklPrimitiveFactory::GetLRUCache(); return lru_cache.GetOp(key); #else @@ -2019,7 +2019,7 @@ class MklPrimitiveFactory { } void SetOp(const string& key, MklPrimitive* op) { -#ifndef DNNL_AARCH64_USE_ACL +#if !defined(DNNL_AARCH64_USE_ACL) || !defined(ENABLE_ONEDNN_OPENMP) auto& lru_cache = MklPrimitiveFactory::GetLRUCache(); lru_cache.SetOp(key, op); #else @@ -2069,7 +2069,7 @@ class MklPrimitiveFactory { private: static inline LRUCache& GetLRUCache() { static const int kCapacity = 1024; // cache capacity -#ifndef DNNL_AARCH64_USE_ACL +#if !defined(DNNL_AARCH64_USE_ACL) || !defined(ENABLE_ONEDNN_OPENMP) static thread_local LRUCache lru_cache_(kCapacity); #else static LRUCache lru_cache_(kCapacity); @@ -2077,7 +2077,7 @@ class MklPrimitiveFactory { return lru_cache_; } -#ifdef DNNL_AARCH64_USE_ACL +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) mutex primitive_creation_mu_; condition_variable primitive_creation_cv_; #endif diff --git a/tensorflow/tsl/util/onednn_threadpool.h b/tensorflow/tsl/util/onednn_threadpool.h index e01f68c39d4365..b8e4131aa43375 100644 --- a/tensorflow/tsl/util/onednn_threadpool.h +++ b/tensorflow/tsl/util/onednn_threadpool.h @@ -161,7 +161,9 @@ class OneDnnThreadPool : public threadpool_iface { num_threads == -1 ? 
eigen_interface_->NumThreads() : num_threads; #if DNNL_VERSION_MAJOR >= 3 || \ (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) +#ifndef DNNL_AARCH64_USE_ACL dnnl_threadpool_interop_set_max_concurrency(num_threads_); +#endif // DNNL_AARCH64_USE_ACL #endif // DNNL_VERSION_MAJOR >= 3 || // (DNNL_VERSION_MAJOR == 2 && DNNL_VERSION_MINOR >= 7) } diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index f6b9376dd7d73b..1ac6c9c9d1a18d 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -211,6 +211,7 @@ def _tf_repositories(): "//third_party/mkl_dnn:onednn_acl_reorder_padded.patch", "//third_party/mkl_dnn:onednn_acl_reorder_update.patch", "//third_party/mkl_dnn:onednn_acl_reorder.patch", + "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", ], sha256 = "a50993aa6265b799b040fe745e0010502f9f7103cc53a9525d59646aef006633", strip_prefix = "oneDNN-2.7.3", @@ -219,7 +220,10 @@ def _tf_repositories(): tf_http_archive( name = "compute_library", - patch_file = ["//third_party/compute_library:compute_library.patch"], + patch_file = [ + "//third_party/compute_library:compute_library.patch", + "//third_party/compute_library:acl_thread_local_scheduler.patch", + ], sha256 = "c4ca329a78da380163b2d86e91ba728349b6f0ee97d66e260a694ef37f0b0d93", strip_prefix = "ComputeLibrary-23.05.1", urls = tf_mirror_urls("https://github.com/ARM-software/ComputeLibrary/archive/v23.05.1.tar.gz"), diff --git a/third_party/compute_library/acl_thread_local_scheduler.patch b/third_party/compute_library/acl_thread_local_scheduler.patch new file mode 100644 index 00000000000000..9ebf6b71fdb44a --- /dev/null +++ b/third_party/compute_library/acl_thread_local_scheduler.patch @@ -0,0 +1,98 @@ +diff --git a/arm_compute/runtime/Scheduler.h b/arm_compute/runtime/Scheduler.h +index 9e8add1f9..cf5e2bf4c 100644 +--- a/arm_compute/runtime/Scheduler.h ++++ b/arm_compute/runtime/Scheduler.h +@@ -75,7 +75,7 @@ public: + + private: + static Type _scheduler_type; +- static std::shared_ptr _custom_scheduler; ++ static thread_local std::shared_ptr _custom_scheduler; + static std::map> _schedulers; + + Scheduler(); +diff --git a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp +index a5b9eca56..d1ab19397 100644 +--- a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp ++++ b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp +@@ -60,8 +60,8 @@ void CpuDepthwiseConv2dAssemblyDispatch::configure(const ITensorInfo *src, + const ConvolutionInfo &info) + { + ARM_COMPUTE_LOG_PARAMS(src, weights, bias, dst, info); +- const CPUInfo &ci = NEScheduler::get().cpu_info(); +- const unsigned int num_threads = NEScheduler::get().num_threads(); ++ const CPUInfo &ci = CPUInfo::get(); ++ const unsigned int num_threads = CPUInfo::get().get_cpu_num(); + _pImpl->is_prepared = false; + _pImpl->are_weights_const = weights->are_values_constant(); + +diff --git a/src/cpu/operators/CpuPool2d.cpp b/src/cpu/operators/CpuPool2d.cpp +index 722cd36ee..03aef1632 100644 +--- a/src/cpu/operators/CpuPool2d.cpp ++++ b/src/cpu/operators/CpuPool2d.cpp +@@ -66,8 +66,8 @@ void CpuPool2d::configure(ITensorInfo *src, ITensorInfo *dst, const PoolingLayer + + if(run_optimised) + { +- const CPUInfo &ci = NEScheduler::get().cpu_info(); +- const unsigned int num_threads = NEScheduler::get().num_threads(); ++ const CPUInfo &ci = CPUInfo::get(); ++ const unsigned int num_threads = CPUInfo::get().get_cpu_num(); + + auto pooling_wrapper = std::make_unique(); + 
ARM_COMPUTE_ERROR_ON(pooling_wrapper == nullptr); +diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +index 9c8563140..f7771945a 100644 +--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp ++++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +@@ -623,8 +623,8 @@ void create_arm_gemm(std::unique_ptr &arm_ge + arm_gemm::Activation activation, const AsmGemmInfo &info) + { + Params p = extract_parameters(a, b, d, info); +- const CPUInfo &ci = NEScheduler::get().cpu_info(); +- unsigned int num_threads = NEScheduler::get().num_threads(); ++ const CPUInfo &ci = CPUInfo::get(); ++ unsigned int num_threads = CPUInfo::get().get_cpu_num(); + + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); +@@ -696,8 +696,8 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected + ARM_COMPUTE_UNUSED(c); + arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); + Params p = extract_parameters(a, b, d, info); +- const CPUInfo &ci = NEScheduler::get().cpu_info(); +- unsigned int num_threads = NEScheduler::get().num_threads(); ++ const CPUInfo &ci = CPUInfo::get(); ++ unsigned int num_threads = CPUInfo::get().get_cpu_num(); + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); +diff --git a/src/runtime/Scheduler.cpp b/src/runtime/Scheduler.cpp +index 0713b9a2a..f15ac2e22 100644 +--- a/src/runtime/Scheduler.cpp ++++ b/src/runtime/Scheduler.cpp +@@ -47,7 +47,7 @@ Scheduler::Type Scheduler::_scheduler_type = Scheduler::Type::CPP; + Scheduler::Type Scheduler::_scheduler_type = Scheduler::Type::ST; + #endif /* ARM_COMPUTE_*_SCHEDULER */ + +-std::shared_ptr Scheduler::_custom_scheduler = nullptr; ++thread_local std::shared_ptr Scheduler::_custom_scheduler = nullptr; + + namespace + { diff --git a/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch b/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch new file mode 100644 index 00000000000000..11d6725f92eba8 --- /dev/null +++ b/third_party/mkl_dnn/onednn_acl_thread_local_scheduler.patch @@ -0,0 +1,98 @@ + ******************************************************************************* + Copyright 2023 Arm Limited and affiliates. + SPDX-License-Identifier: Apache-2.0 + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ******************************************************************************* +diff --git a/src/cpu/aarch64/acl_thread.cpp b/src/cpu/aarch64/acl_thread.cpp +index d7d83badcb..1a7bcd74ed 100644 +--- a/src/cpu/aarch64/acl_thread.cpp ++++ b/src/cpu/aarch64/acl_thread.cpp +@@ -41,14 +41,17 @@ void acl_thread_bind() { + #endif + + #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL +-void acl_set_custom_scheduler() { +- static std::once_flag flag_once; +- // Create threadpool scheduler +- std::shared_ptr threadpool_scheduler +- = std::make_unique(); ++void acl_set_custom_scheduler(int intra_threads = 0) { ++ static thread_local std::once_flag flag_once; + // set CUSTOM scheduler in ACL + std::call_once(flag_once, +- [&]() { arm_compute::Scheduler::set(threadpool_scheduler); }); ++ [&]() { ++ // Create threadpool scheduler ++ std::shared_ptr threadpool_scheduler ++ = std::make_unique(); ++ threadpool_scheduler->set_num_threads(intra_threads); ++ ++ arm_compute::Scheduler::set(threadpool_scheduler); }); + } + + void acl_set_threadpool_num_threads() { +diff --git a/src/cpu/aarch64/acl_thread.hpp b/src/cpu/aarch64/acl_thread.hpp +index 46dde5eb05..13b3910515 100644 +--- a/src/cpu/aarch64/acl_thread.hpp ++++ b/src/cpu/aarch64/acl_thread.hpp +@@ -34,7 +34,7 @@ void acl_thread_bind(); + + #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL + // Retrieve threadpool size during primitive execution and set ThreadpoolScheduler num_threads +-void acl_set_custom_scheduler(); ++void acl_set_custom_scheduler(int intra_threads); + void acl_set_threadpool_num_threads(); + #endif + +diff --git a/src/cpu/aarch64/acl_threadpool_scheduler.cpp b/src/cpu/aarch64/acl_threadpool_scheduler.cpp +index 418d7f30f9..7eb8a052b0 100644 +--- a/src/cpu/aarch64/acl_threadpool_scheduler.cpp ++++ b/src/cpu/aarch64/acl_threadpool_scheduler.cpp +@@ -102,8 +102,6 @@ void ThreadpoolScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, + void ThreadpoolScheduler::run_workloads( + std::vector &workloads) { + +- arm_compute::lock_guard lock(this->_run_workloads_mutex); +- + const unsigned int num_threads + = std::min(static_cast(_num_threads), + static_cast(workloads.size())); +diff --git a/src/cpu/cpu_engine.cpp b/src/cpu/cpu_engine.cpp +index 4ee70a405c..e9211f42e0 100644 +--- a/src/cpu/cpu_engine.cpp ++++ b/src/cpu/cpu_engine.cpp +@@ -47,6 +47,7 @@ status_t cpu_engine_t::create_stream(stream_t **stream, unsigned flags) { + #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL + status_t cpu_engine_t::create_stream(stream_t **stream, + dnnl::threadpool_interop::threadpool_iface *threadpool) { ++ dnnl::impl::cpu::aarch64::acl_thread_utils::acl_set_custom_scheduler(threadpool->get_num_threads()); + return safe_ptr_assign( + *stream, new cpu_stream_t(this, threadpool)); + } +diff --git a/src/cpu/cpu_engine.hpp b/src/cpu/cpu_engine.hpp +index 7aa077e4ef..2938650963 100644 +--- a/src/cpu/cpu_engine.hpp ++++ b/src/cpu/cpu_engine.hpp +@@ -175,11 +175,6 @@ public: + // dnnl_get_max_threads() == OMP_NUM_THREADS + dnnl::impl::cpu::aarch64::acl_thread_utils::acl_thread_bind(); + #endif +- +-#if DNNL_CPU_THREADING_RUNTIME == 
DNNL_RUNTIME_THREADPOOL +- // Set ACL scheduler for threadpool runtime +- dnnl::impl::cpu::aarch64::acl_thread_utils::acl_set_custom_scheduler(); +-#endif + #endif + return status::success; + }; From c1897e958cca297d90209292b4c570c11756b2fc Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Tue, 25 Jul 2023 10:13:19 -0700 Subject: [PATCH 113/410] #tf-data-service Don't use local protocol for local workers in test unless specified. PiperOrigin-RevId: 550921531 --- tensorflow/core/data/service/client/BUILD | 1 + tensorflow/core/data/service/client/data_service_client.cc | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD index 44ae3738331cda..f3252c176331c6 100644 --- a/tensorflow/core/data/service/client/BUILD +++ b/tensorflow/core/data/service/client/BUILD @@ -52,6 +52,7 @@ cc_library( "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", + "//tensorflow/tsl/platform:platform_port", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 696af6fa8c62a9..f98bd77f5fa80e 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tensorflow/tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -359,7 +360,10 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( StatusOr> DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { - if (LocalWorkers::Get(task_info.worker_address()) != nullptr) { + if (params_.data_transfer_protocol == kLocalTransferProtocol || + // TODO(b/291994182): Use remote workers in unit tests. + (tsl::port::JobUid() != -1 && + LocalWorkers::Get(task_info.worker_address()) != nullptr)) { DataTransferServerInfo info; info.set_protocol(kLocalTransferProtocol); info.set_address(task_info.worker_address()); From 3a56044b35c68955f87af172f7a6cb2e48f34296 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 10:15:50 -0700 Subject: [PATCH 114/410] [XLA] Add stack trace to the graphviz. 
PiperOrigin-RevId: 550922361 --- .../compiler/xla/service/hlo_graph_dumper.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index eb0caab8c81cfb..b63b06588a1432 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1253,6 +1253,21 @@ std::string HloDotDumper::GetInstructionNodeMetadata( lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(), instr->metadata().source_line())); } + if (instr->metadata().stack_frame_id() != 0) { + auto hlo_module = instr->parent()->parent(); + int frame_id = instr->metadata().stack_frame_id(); + while (frame_id != 0) { + HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); + if (frame.empty()) { + break; + } + frame_id = frame.parent_frame_id; + + lines.push_back(StrFormat( + "%s:%s:%d%s", frame.file_name, frame.function_name, frame.line, + frame.column == 0 ? "" : StrFormat(":%d", frame.column))); + } + } return StrJoin(lines, "\n"); } From 9f586dbf3a02f464bc12487a928d2f434ec69f6c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 10:19:25 -0700 Subject: [PATCH 115/410] Extract input slices fusion from ir_emitter_unnested. PiperOrigin-RevId: 550923475 --- .../compiler/xla/service/gpu/fusions/BUILD | 18 -- .../xla/service/gpu/fusions/fusions.cc | 4 - .../xla/service/gpu/fusions/input_slices.cc | 184 -------------- .../xla/service/gpu/fusions/input_slices.h | 59 ----- .../xla/service/gpu/ir_emitter_unnested.cc | 225 +++++++++++++++++- .../xla/service/gpu/ir_emitter_unnested.h | 29 +++ 6 files changed, 253 insertions(+), 266 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/input_slices.h diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 22aeb4d6b48827..5bd8adde24b674 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -60,7 +60,6 @@ cc_library( ":copy", ":fusion_emitter", ":in_place_dynamic_update_slice", - ":input_slices", ":loop", ":reduction", ":transpose", @@ -171,20 +170,3 @@ cc_library( "@llvm-project//llvm:ir_headers", ], ) - -cc_library( - name = "input_slices", - srcs = ["input_slices.cc"], - hdrs = ["input_slices.h"], - deps = [ - ":fusion_emitter", - "//tensorflow/compiler/xla/service:elemental_ir_emitter", - "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm-project//llvm:ir_headers", - ], -) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index 602caa5ddc6ba5..cbbd60050eda7f 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/fusions/copy.h" #include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/input_slices.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" #include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" #include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" @@ -52,9 +51,6 @@ std::optional> GetFusionEmitter( ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion) { switch (analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - return std::make_unique( - ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kReduction: return std::make_unique( ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); diff --git a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc deleted file mode 100644 index 2e3edc3a0bbbb0..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.cc +++ /dev/null @@ -1,184 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/input_slices.h" - -#include "llvm/IR/IRBuilder.h" -#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" - -namespace xla { -namespace gpu { -namespace { - -// Emits code for slices based on the below structure. An if statement with -// a guarding condition is generated for each ROOT slice. 
-// -// Pseudo code: -// -// Compute values of slice input operands -// -// Compute guarding_cond0 -// if (guarding_cond0) { -// Write to output of slice0 -// } -// -// Compute guarding_cond1 -// if (guarding_cond1) { -// Write to output of slice1 -// } -// -Status EmitElementForInputFusibleSlices( - ElementalIrEmitter& elemental_emitter, - const HloComputation* fused_computation, - const std::vector& inputs, - const std::vector& outputs, - const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) { - VLOG(10) << "Emitting slice input fusion for " - << fused_computation->ToString(); - - HloInstruction* slice_or_tuple = fused_computation->root_instruction(); - auto slice_instructions = [&]() -> absl::Span { - if (slice_or_tuple->opcode() == HloOpcode::kSlice) { - return absl::Span(&slice_or_tuple, 1); - } - CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); - return slice_or_tuple->operands(); - }(); - - // Emit input operand values of slices. - std::vector input_ir_values; - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [&inputs, i, builder](llvm_ir::IrArray::Index index) { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - for (const HloInstruction* slice : slice_instructions) { - auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0)); - input_ir_values.push_back(input_generator(index).value()); - } - - // Emit for slice_instructions. - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - for (int64_t i = 0; i < slice_instructions.size(); ++i) { - HloInstruction* slice = slice_instructions[i]; - - // guarding_cond := index >= start && index < limit, for each dim. - std::vector index_within_ranges; - for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { - CHECK_EQ(slice->slice_strides(dim), 1); - auto larger_or_equal_than_start = builder->CreateICmpSGE( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - llvm::Value* smaller_than_limit = builder->CreateICmpSLT( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_limits(dim))); - llvm::Value* within_range = - builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit); - index_within_ranges.push_back(within_range); - } - llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges); - - auto emit_slice_elem_func = [&] { - const std::vector& src_multidim = index.multidim(); - std::vector dst_multidim(src_multidim.size()); - for (size_t dim = 0; dim < src_multidim.size(); ++dim) { - dst_multidim[dim] = builder->CreateSub( - src_multidim[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - } - llvm_ir::IrArray src_ir_array = outputs[i]; - llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(), - index.GetType()); - src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], - builder); - }; - - ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func); - } - return OkStatus(); -} - -// Gets the input shape of the ROOT slices, which will be used as the kernel -// launch dims. The slice input fusion requires the input shapes of the ROOT -// slices to be the same although the (slice) output shapes can be different. -// -// Returns the input shape of the ROOT slices if all the input shapes of ROOT -// slices are the same and the slices are non-strided. Otherwise, returns -// FailedPrecondition. 
-StatusOr GetConsistentInputShapeForRootSlices( - const HloComputation* fused_computation) { - const HloInstruction& root = *fused_computation->root_instruction(); - if (root.opcode() == HloOpcode::kSlice) { - return root.operands()[0]->shape(); - } - - CHECK_EQ(root.opcode(), HloOpcode::kTuple); - const Shape& first_slice_operand_shape = - root.operands()[0]->operands()[0]->shape(); - for (size_t i = 1; i < root.operands().size(); ++i) { - const HloInstruction* slice = root.operands()[i]; - const Shape& operand_shape = slice->operands()[0]->shape(); - if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, - operand_shape)) { - return FailedPrecondition( - "Fused slices do not have the same input shape, fused computation = " - "%s.", - root.parent()->name()); - } - } - - return first_slice_operand_shape; -} - -} // namespace - -StatusOr InputSlicesFusion::launch_dimensions() const { - bool use_experimental_block_size = - ir_emitter_context() - .debug_options() - .xla_gpu_enable_experimental_block_size(); - return analysis_.GetLaunchDimensions(use_experimental_block_size); -} - -Status InputSlicesFusion::EmitKernel(const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const { - TF_ASSIGN_OR_RETURN(Shape element_shape, - GetConsistentInputShapeForRootSlices( - fusion().fused_instructions_computation())); - return ParallelLoopEmitter( - [&](const llvm_ir::IrArray::Index index) -> Status { - return EmitElementForInputFusibleSlices( - elemental_emitter(), - fusion().fused_instructions_computation(), inputs, outputs, - index, builder); - }, - element_shape, launch_dims, builder) - .EmitLoop(llvm_ir::IrName(GetIrNameFromLoc(fusion_op().getLoc())), - GetIndexTypeForKernel(fusion_op(), launch_dims.launch_bound(), - builder)); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h b/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h deleted file mode 100644 index 88f6ec90bf0b34..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/input_slices.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ - -#include - -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" - -namespace xla { -namespace gpu { - -// Generates code for input-fusible slices. -// -// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes -// of all ROOT slices need to be the same while their output shapes can be -// different. On the other hand, the input ranges of slices can be -// overlapping. 
Further generalization/specialization when the needs are seen -// in the future. -class InputSlicesFusion : public KernelFusionEmitterBase { - public: - InputSlicesFusion(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, - HloFusionAnalysis& analysis) - : KernelFusionEmitterBase(ir_emitter_context, elemental_emitter, - fusion_op, fusion), - analysis_(analysis) {} - StatusOr launch_dimensions() const override; - - protected: - Status EmitKernel(const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; - - private: - HloFusionAnalysis& analysis_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 75cdc5978a7e4e..e220091a188ddd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -156,6 +157,7 @@ namespace gpu { namespace { +using absl::InlinedVector; using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; @@ -238,6 +240,61 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, } } +bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { + int instruction_count = 0; + for (mlir::Operation& instr : fusion.getRegion().front()) { + if (mlir::isa( + &instr)) { + continue; + } + instruction_count++; + } + return instruction_count == 1; +} + +// Gets the input shape of the ROOT slices, which will be used as the kernel +// launch dims. The slice input fusion requires the input shapes of the ROOT +// slices to be the same although the (slice) output shapes can be different. +// +// Returns the input shape of the ROOT slices if all the input shapes of ROOT +// slices are the same and the slices are non-strided. Otherwise, returns +// FailedPrecondition. +StatusOr GetConsistentInputShapeForRootSlices( + const HloComputation* fused_computation) { + const HloInstruction& root = *fused_computation->root_instruction(); + if (root.opcode() == HloOpcode::kSlice) { + return root.operands()[0]->shape(); + } + + CHECK_EQ(root.opcode(), HloOpcode::kTuple); + const Shape& first_slice_operand_shape = + root.operands()[0]->operands()[0]->shape(); + for (size_t i = 1; i < root.operands().size(); ++i) { + const HloInstruction* slice = root.operands()[i]; + const Shape& operand_shape = slice->operands()[0]->shape(); + if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, + operand_shape)) { + return FailedPrecondition( + "Fused slices do not have the same input shape, fused computation = " + "%s.", + root.parent()->name()); + } + } + + return first_slice_operand_shape; +} + +// For a row reduction, returns the number of rows we can process in parallel +// per warp. 
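// A minimal worked example of the helper defined just below, assuming the
// usual WarpSize() of 32 (the warp width is an assumption, not something this
// code fixes):
//
//   RowReductionGetRowsPerWarp(8)  == 4  // 32 % 8 == 0 and 8 < 32, so 32 / 8 rows share one warp
//   RowReductionGetRowsPerWarp(12) == 1  // 32 % 12 != 0, fall back to one row per warp
//   RowReductionGetRowsPerWarp(64) == 1  // the row is at least as wide as a warp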
+int RowReductionGetRowsPerWarp(int reduced_dimension_size) { + if (WarpSize() % reduced_dimension_size != 0 || + reduced_dimension_size >= WarpSize()) { + return 1; + } + return WarpSize() / reduced_dimension_size; +} + StatusOr AsCudnnfMHAKind( mlir::lmhlo_gpu::FusedMhaDagSignature signature) { switch (signature) { @@ -1959,9 +2016,10 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kInputSlices: + return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); - case HloFusionAnalysis::EmitterFusionKind::kInputSlices: case HloFusionAnalysis::EmitterFusionKind::kLoop: case HloFusionAnalysis::EmitterFusionKind::kReduction: case HloFusionAnalysis::EmitterFusionKind::kTranspose: @@ -1970,6 +2028,44 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { } } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const Shape& reduction_operand_shape, + const ReductionOutputMap& result_ir_arrays, const IrArray::Index& index, + const ReductionCodegenInfo& reduction_info, + const ExtraOutputGensMap& extra_output_gens) { + if (extra_output_gens.empty()) { + return OkStatus(); + } + + // Compute all extra output values before writing them. This avoids + // overwriting aliased input/output buffers before all reads occurred. + std::vector> + extra_output_ir_values; + extra_output_ir_values.reserve(extra_output_gens.size()); + + auto get_index = [&](const HloInstruction* instr) { + const Shape& s = instr->shape(); + return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) + ? index + : index.SourceIndexOfBitcast(reduction_operand_shape, s, &b_); + }; + + for (const auto& [instr, generator] : extra_output_gens) { + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + generator(get_index(instr))); + extra_output_ir_values.emplace_back(instr, extra_output_ir_value); + } + + for (const auto& [instr, generator] : extra_output_ir_values) { + absl::Span result_ir = result_ir_arrays.at(instr); + CHECK_EQ(result_ir.size(), 1); + result_ir[0].EmitWriteArrayElement( + get_index(instr), generator, &b_, /*use_linear_index=*/ + reduction_info.GetNumPartialResults() == 1); + } + return OkStatus(); +} + Status IrEmitterUnnested::AssertNonDeterminismIsOkay( const std::string& op_name) { if (ir_emitter_context_->debug_options().xla_gpu_deterministic_ops()) { @@ -3165,6 +3261,133 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } +// Emits code for slices based on the below structure. An if statement with +// a guarding condition is generated for each ROOT slice. 
+// +// Pseudo code: +// +// Compute values of slice input operands +// +// Compute guarding_cond0 +// if (guarding_cond0) { +// Write to output of slice0 +// } +// +// Compute guarding_cond1 +// if (guarding_cond1) { +// Write to output of slice1 +// } +// +Status IrEmitterUnnested::EmitElementForInputFusibleSlices( + const HloComputation* fused_computation, + absl::Span ir_arrays, + const llvm_ir::IrArray::Index& index) { + VLOG(10) << "Emitting slice input fusion for " + << fused_computation->ToString(); + + HloInstruction* slice_or_tuple = fused_computation->root_instruction(); + auto slice_instructions = [&]() -> absl::Span { + if (slice_or_tuple->opcode() == HloOpcode::kSlice) { + return absl::Span(&slice_or_tuple, 1); + } + CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); + return slice_or_tuple->operands(); + }(); + + // Emit input operand values of slices. + std::vector input_ir_values; + FusedIrEmitter fused_emitter(elemental_emitter_); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + fused_emitter.BindGenerator( + *fused_computation->parameter_instruction(i), + [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { + return ir_arrays[i].EmitReadArrayElement(index, &b_); + }); + } + for (const HloInstruction* slice : slice_instructions) { + auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0)); + input_ir_values.push_back(input_generator(index).value()); + } + + // Emit for slice_instructions. + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (int64_t i = 0; i < slice_instructions.size(); ++i) { + HloInstruction* slice = slice_instructions[i]; + + // guarding_cond := index >= start && index < limit, for each dim. + std::vector index_within_ranges; + for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { + CHECK_EQ(slice->slice_strides(dim), 1); + auto larger_or_equal_than_start = b_.CreateICmpSGE( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + llvm::Value* smaller_than_limit = b_.CreateICmpSLT( + index.multidim()[dim], + index.GetConstantWithIndexType(slice->slice_limits(dim))); + llvm::Value* within_range = + b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit); + index_within_ranges.push_back(within_range); + } + llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges); + + auto emit_slice_elem_func = [&] { + const std::vector& src_multidim = index.multidim(); + std::vector dst_multidim(src_multidim.size()); + for (size_t dim = 0; dim < src_multidim.size(); ++dim) { + dst_multidim[dim] = + Sub(src_multidim[dim], + index.GetConstantWithIndexType(slice->slice_starts(dim))); + } + llvm_ir::IrArray src_ir_array = + ir_arrays[fused_computation->num_parameters() + i]; + IrArray::Index slice_dst_index(dst_multidim, slice->shape(), + index.GetType()); + src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], + &b_); + }; + + ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func); + } + return OkStatus(); +} + +Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices( + mlir::Operation* op, HloFusionAnalysis& fusion_analysis) { + auto fusion = mlir::cast(op); + + TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation, + GetOrCreateSubComputationFromRegion(&fusion.getRegion(), + /*is_fusion=*/true)); + + bool use_experimental_block_size = + ir_emitter_context_->debug_options() + .xla_gpu_enable_experimental_block_size(); + TF_ASSIGN_OR_RETURN( + LaunchDimensions launch_dimensions, + 
fusion_analysis.GetLaunchDimensions(use_experimental_block_size)); + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion(fusion, launch_dimensions)); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. + return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + TF_ASSIGN_OR_RETURN(Shape element_shape, + GetConsistentInputShapeForRootSlices(fused_computation)); + return ParallelLoopEmitter( + [&](const llvm_ir::IrArray::Index index) -> Status { + return EmitElementForInputFusibleSlices(fused_computation, + ir_arrays, index); + }, + element_shape, launch_dimensions, &b_) + .EmitLoop( + IrName(GetIrNameFromLoc(fusion.getLoc())), + GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_)); +} + Status IrEmitterUnnested::EmitScatter(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 5409e2e74292e1..39f5b34e23238f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -317,6 +317,28 @@ class IrEmitterUnnested : public IrEmitter { absl::Span arguments, const LaunchDimensions& launch_dimensions); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce(const Shape& reduction_operand_shape, + const ReductionOutputMap& result_ir_arrays, + const llvm_ir::IrArray::Index& index, + const ReductionCodegenInfo& reduction_info, + const ExtraOutputGensMap& extra_output_gens); + + // Generates code for input-fusible slices. + // + // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes + // of all ROOT slices need to be the same while their output shapes can be + // different. On the other hand, the input ranges of slices can be + // overlapping. Further generalization/specialization when the needs are seen + // in the future. + Status EmitInputFusibleNonStridedSlices(mlir::Operation* op, + HloFusionAnalysis& fusion_analysis); + + Status EmitElementForInputFusibleSlices( + const HloComputation* fused_computation, + absl::Span ir_arrays, + const llvm_ir::IrArray::Index& index); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // the process. Scatter indices are taken from `scatter_indices_gen`, updates // from `updates_gen`. 
The output buffer is expected to have the operand @@ -353,6 +375,13 @@ class IrEmitterUnnested : public IrEmitter { Status EmitScatter(const ScatterDescriptor& desc, const LaunchDimensions& launch_dimensions); + Status EmitTransposeTile(mlir::lmhlo::FusionOp fusion, + const HloComputation* fusion_hlo, + absl::Span operand_arrays, + absl::Span output_arrays, + const TilingScheme& tiling_scheme, + const LaunchDimensions& launch_dimensions); + Status EmitScatter(mlir::lmhlo::FusionOp fusion_op, const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis); From 0cbe63ed1a14e1fdad00f69962169be37a11cbe5 Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Tue, 25 Jul 2023 10:19:39 -0700 Subject: [PATCH 116/410] [xla:gpu] Insert explicit synchronization ops after stream assignment PiperOrigin-RevId: 550923540 --- .../mlir/backends/gpu/transforms/passes.td | 4 +- .../gpu/transforms/stream_assignment.cc | 82 +++++++++++++++- .../transforms/tests/stream_assignment.mlir | 95 +++++++++++++++++++ 3 files changed, 177 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 79b87695eeb4b3..5103dfe0d11418 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -287,8 +287,8 @@ def StreamAssignmentPass: call @xla.gpu.launch.func3 {stream = 2 : i64} // Add explicit synchronization to wait for stream 1 to finish executing - // func2. - xla.stream.await {from = 0, to = 1} + // func2. + call @xla.stream.await {from = 0 : i64, to = [1]} call @xla.gpu.launch.func {stream = 0: i64} func.return } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/stream_assignment.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/stream_assignment.cc index 8fb494d2758d3d..3916e949b12c79 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/stream_assignment.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/stream_assignment.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -22,11 +24,13 @@ limitations under the License. #include "absl/strings/match.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.h" +#include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" namespace xla { namespace gpu { @@ -46,6 +50,8 @@ class StreamAssignmentPass void runOnOperation() override; }; +static constexpr int kNumStreams = 10; + //===----------------------------------------------------------------------===// bool IsParallelizableOp(Operation* op) { @@ -125,9 +131,78 @@ std::vector AssignStreams(const DataflowGraph& graph, int num_streams) { assign_stream_to_dependency_chain(unassigned_node, assigned_stream); } + // next: Assign all non parallelizable ops to stream 0. 
+ return stream_assignment; } +std::optional GetAssignedStream(Operation* op) { + if (op->hasAttr("stream")) { + return op->getAttrOfType("stream").getInt(); + } + return std::nullopt; +} + +// +// Add synchronizations between assigned streams. The added custom call +// xla.streams.await() {from = A, to = [B, C, ...]} makes future work submitted +// to A wait for work that are already submitted to streams B, C, ... +// +// Pseudo code: +// For each node in the dependency graph +// If the node has a stream A assigned +// parents = A's parents +// to_streams = the assigned streams of its parents +// add xla.streams.await() {from = A, to = to_streams} before node +// +// TODO(anlunx): Handle the case where the cuda graph contains non +// parallelizable ops (cuBLAS, cuDNN). +// +void AddSynchronization(runtime::CustomCallDeclarations custom_calls, + const DataflowGraph& graph) { + for (const Node& node : graph) { + Operation* op = node.operation; + std::optional op_stream = GetAssignedStream(op); + if (!op_stream.has_value()) { + continue; + } + int from_stream = op_stream.value(); + + std::array dependent_streams; + dependent_streams.fill(false); + for (int i = 0; i < node.index; i++) { + if (std::find(graph[i].children.begin(), graph[i].children.end(), + node.index) != graph[i].children.end()) { + if (std::optional to_stream = + GetAssignedStream(graph[i].operation)) { + if (to_stream.value() != from_stream) { + dependent_streams[to_stream.value()] = true; + } + } + } + } + + ImplicitLocOpBuilder b(op->getLoc(), custom_calls.sym_table().getOp()); + llvm::SmallVector to_streams; + for (int i = 0; i < kNumStreams; i++) { + if (dependent_streams[i]) { + to_streams.push_back(b.getI64IntegerAttr(i)); + } + } + + if (to_streams.empty()) { + continue; + } + + func::FuncOp await_op = custom_calls.GetOrCreate(b, "xla.streams.await", + TypeRange(), TypeRange()); + b.setInsertionPoint(op); + auto call = b.create(await_op.getName(), TypeRange()); + call->setAttr(b.getStringAttr("from"), b.getI64IntegerAttr(from_stream)); + call->setAttr(b.getStringAttr("to"), b.getArrayAttr(to_streams)); + } +} + //===----------------------------------------------------------------------===// void StreamAssignmentPass::runOnOperation() { @@ -139,19 +214,22 @@ void StreamAssignmentPass::runOnOperation() { } SymbolTable sym_table(func_op->getParentOfType()); - ImplicitLocOpBuilder b(func_op->getLoc(), sym_table.getOp()); DataflowAnalysis dataflow_analysis(func_op); DataflowGraph graph = dataflow_analysis.GetDataflowGraph(func_op); - std::vector stream_assignment = AssignStreams(graph, 10); + std::vector stream_assignment = AssignStreams(graph, kNumStreams); for (auto [index, stream] : llvm::enumerate(stream_assignment)) { Node node = graph[index]; Operation* op = node.operation; + ImplicitLocOpBuilder b(op->getLoc(), sym_table.getOp()); if (stream != -1) { op->setAttr(b.getStringAttr("stream"), b.getI64IntegerAttr(stream)); } } + + runtime::CustomCallDeclarations custom_calls(std::move(sym_table)); + AddSynchronization(custom_calls, graph); } } // namespace diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir index 8cef428174a9c2..87c3e64673e207 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/stream_assignment.mlir @@ -32,6 +32,7 @@ module attributes {gpu.container_module} { // 
CHECK-SAME: {stream = 1 : i64} gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) + // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1]} // CHECK: gpu.launch_func @gpu_module::@fn2 // CHECK-SAME: {stream = 0 : i64} gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) @@ -70,6 +71,7 @@ module attributes {gpu.container_module} { // CHECK-SAME: {stream = 0 : i64} gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%view : memref<3x3xi64>, %view_0 : memref<3x3xi64>) + // CHECK: call @xla.streams.await() {from = 1 : i64, to = [0]} // CHECK: gpu.launch_func @gpu_module::@fn1 // CHECK-SAME: {stream = 1 : i64} gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) @@ -81,3 +83,96 @@ module attributes {gpu.container_module} { return } } + +// ----- +// Check that stream with multiple dependencies is handled correctly. +// A B C-->D +// | | ^ +// | |--------| +// +-------------+ +// +// Stream assignment: A->0 B->1 C->2 D->0 +// + +module attributes {gpu.container_module} { + + gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } + gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64> {lmhlo.written}, %arg3: memref<3x3xi64>) kernel { gpu.return } + } + + + // CHECK: func @xla.gpu.cuda.graph.capture + func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>, %arg2: memref<72xi8>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> + %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> + %view_2 = memref.view %arg2[%c0][] : memref<72xi8> to memref<3x3xi64> + + // CHECK: gpu.launch_func @gpu_module::@fn1 + // CHECK-SAME: {stream = 0 : i64} + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) + // CHECK: gpu.launch_func @gpu_module::@fn1 + // CHECK-SAME: {stream = 1 : i64} + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) + // CHECK: gpu.launch_func @gpu_module::@fn1 + // CHECK-SAME: {stream = 2 : i64} + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_2 : memref<3x3xi64>) + // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1, 2]} + // CHECK: gpu.launch_func @gpu_module::@fn2 + // CHECK-SAME: {stream = 0 : i64} + gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>, %view_2 : memref<3x3xi64>) + return + } +} + +// ----- +// Check that stream synchronization only happens when two streams joins. 
+// A B--->C-->D +// | ^ +// | | +// +---------+ +// +// Stream assignment: A->0 B->1 C->0 D->0 +// + +module attributes {gpu.container_module} { + + gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn1(%arg0: memref<3x3xi64> {lmhlo.written} ) kernel { gpu.return } + gpu.func @fn2(%arg0: memref<3x3xi64> {lmhlo.written}, %arg1: memref<3x3xi64>) kernel { gpu.return } + } + + + // CHECK: func @xla.gpu.cuda.graph.capture + func.func @xla.gpu.cuda.graph.capture(%arg0: memref<72xi8>, %arg1: memref<72xi8>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %view_0 = memref.view %arg0[%c0][] : memref<72xi8> to memref<3x3xi64> + %view_1 = memref.view %arg1[%c0][] : memref<72xi8> to memref<3x3xi64> + + // CHECK: gpu.launch_func @gpu_module::@fn1 + // CHECK-SAME: {stream = 0 : i64} + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>) + // CHECK: gpu.launch_func @gpu_module::@fn1 + // CHECK-SAME: {stream = 1 : i64} + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_1 : memref<3x3xi64>) + // CHECK: call @xla.streams.await() {from = 0 : i64, to = [1]} + // CHECK: gpu.launch_func @gpu_module::@fn2 + // CHECK-SAME: {stream = 0 : i64} + gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) + // CHECK-NEXT: gpu.launch_func @gpu_module::@fn2 + // CHECK-SAME: {stream = 0 : i64} + gpu.launch_func @gpu_module::@fn2 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view_0 : memref<3x3xi64>, %view_1 : memref<3x3xi64>) + return + } +} From 1b49fd8614a0ee67f0c9516e6a161822ee899919 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 25 Jul 2023 10:37:11 -0700 Subject: [PATCH 117/410] [XLA] Print computation name after closing brace when dumping HLO - This can help when working with large computations to find which computation the instruction belongs to. PiperOrigin-RevId: 550929464 --- tensorflow/compiler/xla/hlo/ir/hlo_computation.cc | 4 ++++ tensorflow/compiler/xla/hlo/ir/hlo_instruction.h | 12 +++++++++++- tensorflow/compiler/xla/service/dump.cc | 1 + tensorflow/compiler/xla/service/hlo_parser_test.cc | 9 +++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc index 770a4919082bda..0a7e8fbab1b2cb 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc @@ -680,6 +680,10 @@ void HloComputation::Print( printer->Append(execution_thread()); printer->Append("\""); } + if (options.print_name_after_closing_brace() && instruction_count() > 5) { + printer->Append(" // "); + printer->Append(name()); + } } std::string HloComputation::ToString() const { diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h index 72594b4d2dbc25..81bb229aea89e6 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h @@ -95,7 +95,8 @@ class HloPrintOptions { print_ids_(true), canonicalize_computations_(false), print_extra_attributes_(true), - syntax_sugar_async_ops_(true) {} + syntax_sugar_async_ops_(true), + print_name_after_closing_brace_(false) {} // Static reference to a default construction HloPrintOptions, to avoid // constructing a new one each time default is needed. 
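// (A short sketch of the print_name_after_closing_brace option introduced
// above: when it is set and a computation has more than five instructions, the
// dump annotates the closing brace with the computation's name, e.g.
//
//   test {
//     ...
//   } // test
//
// which makes it easier to see which computation a long instruction list
// belongs to in a large dump.)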
static const HloPrintOptions& Default() { @@ -338,6 +339,11 @@ class HloPrintOptions { return *this; } + HloPrintOptions& set_print_name_after_closing_brace(bool value) { + print_name_after_closing_brace_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } bool print_only_essential_constants() const { return print_only_essential_constants_; @@ -372,6 +378,9 @@ class HloPrintOptions { bool canonicalize_computations() const { return canonicalize_computations_; } int indent_amount() const { return indent_amount_; } int is_in_nested_computation() const { return is_in_nested_computation_; } + int print_name_after_closing_brace() const { + return print_name_after_closing_brace_; + } private: // The interval between the /*index=*/ annotated operands. 0 means never print @@ -398,6 +407,7 @@ class HloPrintOptions { bool canonicalize_computations_; bool print_extra_attributes_; bool syntax_sugar_async_ops_; + bool print_name_after_closing_brace_; }; // For canonical string output, we need to have a canonical way to rename diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index f8d73799807d94..ed360ee9d37684 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -396,6 +396,7 @@ static std::vector DumpHloModuleImpl( print_options.set_print_operand_index_annotation_interval(5); print_options.set_print_backend_config(true); print_options.set_print_metadata(opts.dump_hlo_metadata); + print_options.set_print_name_after_closing_brace(true); file_paths.push_back(DumpToFileInDirOrStdoutImpl( StrCat(filename, ".txt"), module.ToString(print_options), opts)); if (buffer_assn) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 332dd413835fbb..fc9aa9300735fb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -4513,6 +4513,15 @@ test { Layout({1, 0, 2, 3})); } +TEST_F(HloParserTest, ParseComputationNameClosingBrace) { + const std::string original = R"( +test { + ROOT root = f32[1,64,10,128]{1,0,2,3} parameter(0) +} // test +)"; + EXPECT_TRUE(ParseAndReturnUnverifiedModule(original).ok()); +} + TEST_F(HloParserTest, ParseSingleEntryComputation) { const std::string original = R"( ENTRY test { From dc37c6306a7bd53e88944ca3041ad57539a0c805 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Tue, 25 Jul 2023 10:38:58 -0700 Subject: [PATCH 118/410] Fix Python dependencies for target compiler/mlir/tensorflow:gen_mlir_passthrough_op_py. PiperOrigin-RevId: 550930091 --- tensorflow/compiler/mlir/tensorflow/BUILD | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 18148ed4d480d1..93a27ea3bf2905 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -2355,6 +2355,12 @@ tf_gen_op_wrapper_py( name = "gen_mlir_passthrough_op_py", out = "gen_mlir_passthrough_op.py", compatible_with = [], + extra_py_deps = [ + "//tensorflow/python:pywrap_tfe", + "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:deprecation", + "//tensorflow/python/util:tf_export", + ], py_lib_rule = py_strict_library, deps = [":mlir_passthrough_op"], ) From a6bab9028be057a7b650b2f2a01de494a178b7de Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jul 2023 10:41:43 -0700 Subject: [PATCH 119/410] Minor fix checkfile tests for the @quantize_i8 and @dequantize_qi8. PiperOrigin-RevId: 550930987 --- .../tensorflow/tests/insert_quantized_functions.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir index c661b6f7293a27..74d0e1b3c8e9ec 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir @@ -35,12 +35,10 @@ module { // CHECK: func private @quantize_i8 // CHECK: func private @dequantize_i8 -// UQ-CHECK-NOT: func private @dequantize_i8 // UQ-CHECK-NOT: func private @internal_conv2d_fn // UQ-CHECK-NOT: func private @internal_requantize_qi8_fn // UQ-CHECK-NOT: func private @internal_requantize_no_activation_fn // UQ-CHECK-NOT: func private @internal_requantize_and_relu_fn -// UQ-CHECK-NOT: func private @quantize_i8 // UQ-CHECK: func private @quantized_conv2d_with_bias_fn // UQ-CHECK-SAME: tf_quant.quantized_ops = ["Conv2D", "BiasAdd"] // UQ-CHECK: func private @quantized_conv2d_with_bias_and_relu_fn @@ -53,3 +51,5 @@ module { // UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu6_fn // UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu_fn // UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu6_fn +// UQ-CHECK: func private @quantize_i8 +// UQ-CHECK: func private @dequantize_i8 From 01ad888ffabda93b01e4f0ac5ee802448703ea7f Mon Sep 17 00:00:00 2001 From: Son Tuan Vu Date: Tue, 25 Jul 2023 10:57:29 -0700 Subject: [PATCH 120/410] [XLA:GPU][NFC] Temporarily disable cuda graphs when profiling for CUDA version >= 12 PiperOrigin-RevId: 550936118 --- tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc index de914f695d726a..ae4c6f12306a52 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/graph_launch.cc @@ -563,11 +563,9 @@ static absl::Status LaunchGraph( // TODO(ezhulenev): Cupti tracing leads to deadlocks in CUDA 11. Always fall // back on regular execution if we detect tracing activity. -#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 - bool is_profiling = false; -#else + // HLO ops profiling can't capture individual kernels when cuda graphs are + // enabled, so we want to make sure they are disabled when profiling. 
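// A reduced sketch of the launch path that follows (the two callee names are
// placeholders, not functions from this file): while a profiler annotation
// stack is active, or during the first num_runs_to_instantiate warm-up runs,
// execution stays on the op-by-op path so each kernel remains visible to the
// profiler.
//
//   if (count < num_runs_to_instantiate || is_profiling) {
//     RunCaptureRegionOpByOp();    // placeholder for the op-by-op fallback
//   } else {
//     LaunchInstantiatedGraph();   // placeholder for the cuda graph launch
//   }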
bool is_profiling = tsl::profiler::ScopedAnnotationStack::IsEnabled(); -#endif if (count < num_runs_to_instantiate || is_profiling) { VLOG(3) << "Run gpu graph in op-by-op mode: ordinal = " << capture.ordinal; From 039111caeb1bbfe7d4ecc0b3a14fa77f7719bfae Mon Sep 17 00:00:00 2001 From: rahulbatra85 Date: Tue, 25 Jul 2023 11:12:19 -0700 Subject: [PATCH 121/410] PR #4484: [ROCm]: Enable bf16 support in NCCL Imported from GitHub PR https://github.com/openxla/xla/pull/4484 Copybara import of the project: -- d275a9e8064f14b8ccf0379c765d9223ccbf6084 by Rahul Batra : [ROCm]: Enable bf16 support in NCCL Merging this change closes #4484 PiperOrigin-RevId: 550941099 --- tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc | 2 +- tensorflow/compiler/xla/service/gpu/nccl_utils.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc index 0f193ab0d327e9..f888837bd2d0c7 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc @@ -44,7 +44,7 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type, case F16: case F32: case F64: -#if defined(__CUDA_BF16_TYPES_EXIST__) +#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM case BF16: #endif case C64: diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc index 68b630676544e5..f058aea63e8409 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc @@ -105,7 +105,7 @@ StatusOr ToNcclDataType(PrimitiveType element_type, // For collectives that just move data around, we can use ncclFloat16 for // 16-bit integer data types. return ncclFloat16; -#if defined(__CUDA_BF16_TYPES_EXIST__) +#if defined(__CUDA_BF16_TYPES_EXIST__) || TENSORFLOW_USE_ROCM case BF16: return ncclBfloat16; #endif From 231550c3b23831c143da65ad83fb191da8ac661f Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jul 2023 11:18:06 -0700 Subject: [PATCH 122/410] Update requirements_lock files with the newest versions PiperOrigin-RevId: 550943032 --- requirements_lock_3_10.txt | 358 ++++++++++++++++++----------------- requirements_lock_3_11.txt | 358 ++++++++++++++++++----------------- requirements_lock_3_9.txt | 370 ++++++++++++++++++------------------- 3 files changed, 537 insertions(+), 549 deletions(-) diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index a3c84ffba291ad..65c54527d47670 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -6,90 +6,90 @@ cachetools==5.3.1 \ --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \ --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b # via google-auth -certifi==2023.5.7 \ - --hash=sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7 \ - --hash=sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716 +certifi==2023.7.22 \ + --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ + --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 # via requests -charset-normalizer==3.1.0 \ - --hash=sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6 \ - --hash=sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1 \ - --hash=sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e \ - --hash=sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373 \ - --hash=sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62 \ - --hash=sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230 \ - --hash=sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be \ - --hash=sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c \ - --hash=sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0 \ - --hash=sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448 \ - --hash=sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f \ - --hash=sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649 \ - --hash=sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d \ - --hash=sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0 \ - --hash=sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706 \ - --hash=sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a \ - --hash=sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59 \ - --hash=sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23 \ - --hash=sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5 \ - --hash=sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb \ - --hash=sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e \ - --hash=sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e \ - --hash=sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c \ - --hash=sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28 \ - --hash=sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d \ - --hash=sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41 \ - --hash=sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974 \ - 
--hash=sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce \ - --hash=sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f \ - --hash=sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1 \ - --hash=sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d \ - --hash=sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8 \ - --hash=sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017 \ - --hash=sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31 \ - --hash=sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7 \ - --hash=sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8 \ - --hash=sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e \ - --hash=sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14 \ - --hash=sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd \ - --hash=sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d \ - --hash=sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795 \ - --hash=sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b \ - --hash=sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b \ - --hash=sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b \ - --hash=sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203 \ - --hash=sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f \ - --hash=sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19 \ - --hash=sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1 \ - --hash=sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a \ - --hash=sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac \ - --hash=sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9 \ - --hash=sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0 \ - --hash=sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137 \ - --hash=sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f \ - --hash=sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6 \ - --hash=sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5 \ - --hash=sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909 \ - --hash=sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f \ - --hash=sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0 \ - --hash=sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324 \ - --hash=sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755 \ - --hash=sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb \ - --hash=sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854 \ - --hash=sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c \ - --hash=sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60 \ - --hash=sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84 \ - --hash=sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0 \ - --hash=sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b \ - --hash=sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1 \ - 
--hash=sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531 \ - --hash=sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1 \ - --hash=sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11 \ - --hash=sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326 \ - --hash=sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df \ - --hash=sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab +charset-normalizer==3.2.0 \ + --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ + --hash=sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c \ + --hash=sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710 \ + --hash=sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706 \ + --hash=sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020 \ + --hash=sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252 \ + --hash=sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad \ + --hash=sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329 \ + --hash=sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a \ + --hash=sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f \ + --hash=sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6 \ + --hash=sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4 \ + --hash=sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a \ + --hash=sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46 \ + --hash=sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2 \ + --hash=sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23 \ + --hash=sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace \ + --hash=sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd \ + --hash=sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982 \ + --hash=sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10 \ + --hash=sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2 \ + --hash=sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea \ + --hash=sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09 \ + --hash=sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5 \ + --hash=sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149 \ + --hash=sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489 \ + --hash=sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9 \ + --hash=sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80 \ + --hash=sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592 \ + --hash=sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3 \ + --hash=sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6 \ + --hash=sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed \ + --hash=sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c \ + --hash=sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200 \ + --hash=sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a \ + --hash=sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e \ + 
--hash=sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d \ + --hash=sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6 \ + --hash=sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623 \ + --hash=sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669 \ + --hash=sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3 \ + --hash=sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa \ + --hash=sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9 \ + --hash=sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2 \ + --hash=sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f \ + --hash=sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1 \ + --hash=sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4 \ + --hash=sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a \ + --hash=sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8 \ + --hash=sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3 \ + --hash=sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029 \ + --hash=sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f \ + --hash=sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959 \ + --hash=sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22 \ + --hash=sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7 \ + --hash=sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952 \ + --hash=sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346 \ + --hash=sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e \ + --hash=sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d \ + --hash=sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299 \ + --hash=sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd \ + --hash=sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a \ + --hash=sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3 \ + --hash=sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037 \ + --hash=sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94 \ + --hash=sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c \ + --hash=sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858 \ + --hash=sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a \ + --hash=sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449 \ + --hash=sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c \ + --hash=sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918 \ + --hash=sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1 \ + --hash=sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c \ + --hash=sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac \ + --hash=sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa # via requests -google-auth==2.19.1 \ - --hash=sha256:a9cfa88b3e16196845e64a3658eb953992129d13ac7337b064c6546f77c17183 \ - --hash=sha256:ea165e014c7cbd496558796b627c271aa8c18b4cba79dc1cc962b24c5efdfb85 +google-auth==2.22.0 \ + --hash=sha256:164cba9af4e6e4e40c3a4f90a1a6c12ee56f14c0b4868d1ca91b32826ab334ce \ + 
--hash=sha256:d61d1b40897407b574da67da1a833bdc10d5a11642566e506565d1b1a46ba873 # via # google-auth-oauthlib # tb-nightly @@ -97,81 +97,77 @@ google-auth-oauthlib==1.0.0 \ --hash=sha256:95880ca704928c300f48194d1770cf5b1462835b6e49db61445a520f793fd5fb \ --hash=sha256:e375064964820b47221a7e1b7ee1fd77051b6323c3f9e3e19785f78ab67ecfc5 # via tb-nightly -grpcio==1.54.2 \ - --hash=sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3 \ - --hash=sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3 \ - --hash=sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821 \ - --hash=sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea \ - --hash=sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98 \ - --hash=sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5 \ - --hash=sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8 \ - --hash=sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08 \ - --hash=sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c \ - --hash=sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12 \ - --hash=sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c \ - --hash=sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f \ - --hash=sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e \ - --hash=sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796 \ - --hash=sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019 \ - --hash=sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474 \ - --hash=sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c \ - --hash=sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598 \ - --hash=sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7 \ - --hash=sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe \ - --hash=sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94 \ - --hash=sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c \ - --hash=sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05 \ - --hash=sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8 \ - --hash=sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf \ - --hash=sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b \ - --hash=sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771 \ - --hash=sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b \ - --hash=sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f \ - --hash=sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148 \ - --hash=sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a \ - --hash=sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610 \ - --hash=sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a \ - --hash=sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa \ - --hash=sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee \ - --hash=sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e \ - --hash=sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b \ - --hash=sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c \ - 
--hash=sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb \ - --hash=sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1 \ - --hash=sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4 \ - --hash=sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb \ - --hash=sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a \ - --hash=sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2 \ - --hash=sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb +grpcio==1.56.2 \ + --hash=sha256:06e84ad9ae7668a109e970c7411e7992751a116494cba7c4fb877656527f9a57 \ + --hash=sha256:0ff789ae7d8ddd76d2ac02e7d13bfef6fc4928ac01e1dcaa182be51b6bcc0aaa \ + --hash=sha256:10954662f77dc36c9a1fb5cc4a537f746580d6b5734803be1e587252682cda8d \ + --hash=sha256:139f66656a762572ae718fa0d1f2dce47c05e9fbf7a16acd704c354405b97df9 \ + --hash=sha256:1c31e52a04e62c8577a7bf772b3e7bed4df9c9e0dd90f92b6ffa07c16cab63c9 \ + --hash=sha256:33971197c47965cc1d97d78d842163c283e998223b151bab0499b951fd2c0b12 \ + --hash=sha256:345356b307cce5d14355e8e055b4ca5f99bc857c33a3dc1ddbc544fca9cd0475 \ + --hash=sha256:373b48f210f43327a41e397391715cd11cfce9ded2fe76a5068f9bacf91cc226 \ + --hash=sha256:3ccb621749a81dc7755243665a70ce45536ec413ef5818e013fe8dfbf5aa497b \ + --hash=sha256:42a3bbb2bc07aef72a7d97e71aabecaf3e4eb616d39e5211e2cfe3689de860ca \ + --hash=sha256:42e63904ee37ae46aa23de50dac8b145b3596f43598fa33fe1098ab2cbda6ff5 \ + --hash=sha256:4eb37dd8dd1aa40d601212afa27ca5be255ba792e2e0b24d67b8af5e012cdb7d \ + --hash=sha256:51173e8fa6d9a2d85c14426bdee5f5c4a0654fd5fddcc21fe9d09ab0f6eb8b35 \ + --hash=sha256:5144feb20fe76e73e60c7d73ec3bf54f320247d1ebe737d10672480371878b48 \ + --hash=sha256:5344be476ac37eb9c9ad09c22f4ea193c1316bf074f1daf85bddb1b31fda5116 \ + --hash=sha256:6108e5933eb8c22cd3646e72d5b54772c29f57482fd4c41a0640aab99eb5071d \ + --hash=sha256:6a007a541dff984264981fbafeb052bfe361db63578948d857907df9488d8774 \ + --hash=sha256:6ee26e9dfb3996aff7c870f09dc7ad44a5f6732b8bdb5a5f9905737ac6fd4ef1 \ + --hash=sha256:750de923b456ca8c0f1354d6befca45d1f3b3a789e76efc16741bd4132752d95 \ + --hash=sha256:7c5ede2e2558f088c49a1ddda19080e4c23fb5d171de80a726b61b567e3766ed \ + --hash=sha256:830215173ad45d670140ff99aac3b461f9be9a6b11bee1a17265aaaa746a641a \ + --hash=sha256:8391cea5ce72f4a12368afd17799474015d5d3dc00c936a907eb7c7eaaea98a5 \ + --hash=sha256:8940d6de7068af018dfa9a959a3510e9b7b543f4c405e88463a1cbaa3b2b379a \ + --hash=sha256:89a49cc5ad08a38b6141af17e00d1dd482dc927c7605bc77af457b5a0fca807c \ + --hash=sha256:900bc0096c2ca2d53f2e5cebf98293a7c32f532c4aeb926345e9747452233950 \ + --hash=sha256:97e0efaebbfd222bcaac2f1735c010c1d3b167112d9d237daebbeedaaccf3d1d \ + --hash=sha256:9e04d4e4cfafa7c5264e535b5d28e786f0571bea609c3f0aaab13e891e933e9c \ + --hash=sha256:a4c60abd950d6de3e4f1ddbc318075654d275c29c846ab6a043d6ed2c52e4c8c \ + --hash=sha256:a6ff459dac39541e6a2763a4439c4ca6bc9ecb4acc05a99b79246751f9894756 \ + --hash=sha256:a72797549935c9e0b9bc1def1768c8b5a709538fa6ab0678e671aec47ebfd55e \ + --hash=sha256:af4063ef2b11b96d949dccbc5a987272f38d55c23c4c01841ea65a517906397f \ + --hash=sha256:b975b85d1d5efc36cf8b237c5f3849b64d1ba33d6282f5e991f28751317504a1 \ + --hash=sha256:bf0b9959e673505ee5869950642428046edb91f99942607c2ecf635f8a4b31c9 \ + --hash=sha256:c0c85c5cbe8b30a32fa6d802588d55ffabf720e985abe9590c7c886919d875d4 \ + --hash=sha256:c3f3237a57e42f79f1e560726576aedb3a7ef931f4e3accb84ebf6acc485d316 \ + 
--hash=sha256:c3fa3ab0fb200a2c66493828ed06ccd1a94b12eddbfb985e7fd3e5723ff156c6 \ + --hash=sha256:c435f5ce1705de48e08fcbcfaf8aee660d199c90536e3e06f2016af7d6a938dd \ + --hash=sha256:c90da4b124647547a68cf2f197174ada30c7bb9523cb976665dfd26a9963d328 \ + --hash=sha256:cbdf2c498e077282cd427cfd88bdce4668019791deef0be8155385ab2ba7837f \ + --hash=sha256:d1fbad1f9077372b6587ec589c1fc120b417b6c8ad72d3e3cc86bbbd0a3cee93 \ + --hash=sha256:d39f5d4af48c138cb146763eda14eb7d8b3ccbbec9fe86fb724cd16e0e914c64 \ + --hash=sha256:ddb4a6061933bd9332b74eac0da25f17f32afa7145a33a0f9711ad74f924b1b8 \ + --hash=sha256:ded637176addc1d3eef35331c39acc598bac550d213f0a1bedabfceaa2244c87 \ + --hash=sha256:f20fd21f7538f8107451156dd1fe203300b79a9ddceba1ee0ac8132521a008ed \ + --hash=sha256:fda2783c12f553cdca11c08e5af6eecbd717280dc8fbe28a110897af1c15a88c # via # -r ./requirements.in # tb-nightly -h5py==3.8.0 \ - --hash=sha256:03890b1c123d024fb0239a3279737d5432498c1901c354f8b10d8221d1d16235 \ - --hash=sha256:0fef76e10b9216657fa37e7edff6d8be0709b25bd5066474c229b56cf0098df9 \ - --hash=sha256:26ffc344ec9984d2cd3ca0265007299a8bac8d85c1ad48f4639d8d3aed2af171 \ - --hash=sha256:290e00fa2de74a10688d1bac98d5a9cdd43f14f58e562c580b5b3dfbd358ecae \ - --hash=sha256:33b15aae79e9147aebe1d0e54099cbcde8d65e3e227cd5b59e49b1272aa0e09d \ - --hash=sha256:36761693efbe53df179627a775476dcbc37727d6e920958277a7efbc18f1fb73 \ - --hash=sha256:377865821fe80ad984d003723d6f8890bd54ceeb5981b43c0313b9df95411b30 \ - --hash=sha256:49bc857635f935fa30e92e61ac1e87496df8f260a6945a3235e43a9890426866 \ - --hash=sha256:4a506fc223def428f4329e7e1f9fe1c8c593eab226e7c0942c8d75308ad49950 \ - --hash=sha256:533d7dad466ddb7e3b30af274b630eb7c1a6e4ddf01d1c373a0334dc2152110a \ - --hash=sha256:5fd2252d1fc364ba0e93dd0b7089f4906b66805cb4e6aca7fa8874ac08649647 \ - --hash=sha256:6fead82f0c4000cf38d53f9c030780d81bfa0220218aee13b90b7701c937d95f \ - --hash=sha256:7f3350fc0a8407d668b13247861c2acd23f7f5fe7d060a3ad9b0820f5fcbcae0 \ - --hash=sha256:8f55d9c6c84d7d09c79fb85979e97b81ec6071cc776a97eb6b96f8f6ec767323 \ - --hash=sha256:98a240cd4c1bfd568aaa52ec42d263131a2582dab82d74d3d42a0d954cac12be \ - --hash=sha256:9f6f6ffadd6bfa9b2c5b334805eb4b19ca0a5620433659d8f7fb86692c40a359 \ - --hash=sha256:b685453e538b2b5934c58a644ac3f3b3d0cec1a01b6fb26d57388e9f9b674ad0 \ - --hash=sha256:b7865de06779b14d98068da387333ad9bf2756b5b579cc887fac169bc08f87c3 \ - --hash=sha256:bacaa1c16810dd2b3e4417f8e730971b7c4d53d234de61fe4a918db78e80e1e4 \ - --hash=sha256:bae730580ae928de409d63cbe4fdca4c82c3ad2bed30511d19d34e995d63c77e \ - --hash=sha256:c3389b63222b1c7a158bb7fe69d11ca00066740ec5574596d47a2fe5317f563a \ - --hash=sha256:c873ba9fd4fa875ad62ce0e4891725e257a8fe7f5abdbc17e51a5d54819be55c \ - --hash=sha256:db03e3f2c716205fbdabb34d0848459840585225eb97b4f08998c743821ca323 \ - --hash=sha256:f47f757d1b76f0ecb8aa0508ec8d1b390df67a8b67ee2515dc1b046f3a1596ea \ - --hash=sha256:f891b17e3a3e974e93f9e34e7cca9f530806543571ce078998676a555837d91d +h5py==3.9.0 \ + --hash=sha256:12aa556d540f11a2cae53ea7cfb94017353bd271fb3962e1296b342f6550d1b8 \ + --hash=sha256:23e74b878bbe1653ab34ca49b83cac85529cd0b36b9d625516c5830cc5ca2eac \ + --hash=sha256:36408f8c62f50007d14e000f9f3acf77e103b9e932c114cbe52a3089e50ebf94 \ + --hash=sha256:3f457089c5d524b7998e3649bc63240679b8fb0a3859ea53bbb06841f3d755f1 \ + --hash=sha256:54f01202cdea754ab4227dd27014bdbd561a4bbe4b631424fd812f7c2ce9c6ac \ + --hash=sha256:551e358db05a874a0f827b22e95b30092f2303edc4b91bb62ad2f10e0236e1a0 \ + --hash=sha256:64acceaf6aff92af091a4b83f6dee3cf8d3061f924a6bb3a33eb6c4658a8348b \ + 
--hash=sha256:6822a814b9d8b8363ff102f76ea8d026f0ca25850bb579d85376029ee3e73b93 \ + --hash=sha256:78e44686334cbbf2dd21d9df15823bc38663f27a3061f6a032c68a3e30c47bf7 \ + --hash=sha256:79bbca34696c6f9eeeb36a91776070c49a060b2879828e2c8fa6c58b8ed10dd1 \ + --hash=sha256:804c7fb42a34c8ab3a3001901c977a5c24d2e9c586a0f3e7c0a389130b4276fc \ + --hash=sha256:8d9492391ff5c3c80ec30ae2fe82a3f0efd1e750833739c25b0d090e3be1b095 \ + --hash=sha256:95f7a745efd0d56076999b52e8da5fad5d30823bac98b59c68ae75588d09991a \ + --hash=sha256:9da9e7e63376c32704e37ad4cea2dceae6964cee0d8515185b3ab9cbd6b947bc \ + --hash=sha256:a4e20897c88759cbcbd38fb45b507adc91af3e0f67722aa302d71f02dd44d286 \ + --hash=sha256:a6284061f3214335e1eec883a6ee497dbe7a79f19e6a57fed2dd1f03acd5a8cb \ + --hash=sha256:d97409e17915798029e297a84124705c8080da901307ea58f29234e09b073ddc \ + --hash=sha256:dbf5225543ca35ce9f61c950b73899a82be7ba60d58340e76d0bd42bf659235a \ + --hash=sha256:e604db6521c1e367c6bd7fad239c847f53cc46646f2d2651372d05ae5e95f817 \ + --hash=sha256:eb7bdd5e601dd1739698af383be03f3dad0465fe67184ebd5afca770f50df9d6 \ + --hash=sha256:f68b41efd110ce9af1cbe6fa8af9f4dcbadace6db972d30828b911949e28fadd # via -r ./requirements.in idna==3.4 \ --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ @@ -180,12 +176,12 @@ idna==3.4 \ jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023061207 \ - --hash=sha256:210671f010a0b21a5507be86b8e9e909f81b9f321cd3c51e1efdfdd41061919f \ - --hash=sha256:ad869b2bce863e111e4a57c7f5785f56097d93f683b5315df7f59917be1fa279 +keras-nightly==2.14.0.dev2023072407 \ + --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ + --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 # via -r ./requirements.in -lit==16.0.5.post0 \ - --hash=sha256:71745d9e58dad3717735d27e2a9cca0e9ca6861d067da73c307e02fd38c98479 +lit==16.0.6 \ + --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a # via -r ./requirements.in markdown==3.4.3 \ --hash=sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2 \ @@ -315,20 +311,20 @@ portpicker==1.5.2 \ --hash=sha256:01113f51c3cc63290a44dd7ae6e3eb9f8fe1b8a1f9d7988a897944230c39cd52 \ --hash=sha256:c55683ad725f5c00a41bc7db0225223e8be024b1fa564d039ed3390e4fd48fb3 # via -r ./requirements.in -protobuf==4.23.2 \ - --hash=sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a \ - --hash=sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7 \ - --hash=sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511 \ - --hash=sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f \ - --hash=sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3 \ - --hash=sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f \ - --hash=sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e \ - --hash=sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa \ - --hash=sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f \ - --hash=sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df \ - --hash=sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e \ - --hash=sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea \ - --hash=sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91 +protobuf==4.23.4 \ + 
--hash=sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474 \ + --hash=sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2 \ + --hash=sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b \ + --hash=sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720 \ + --hash=sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12 \ + --hash=sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd \ + --hash=sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0 \ + --hash=sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e \ + --hash=sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9 \ + --hash=sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70 \ + --hash=sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff \ + --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ + --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly psutil==5.9.5 \ --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \ @@ -406,16 +402,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230612 \ - --hash=sha256:1ad7d57386f6103df69d8f3083552d185d33c5a07fb18e691b60236d9b4dc679 +tb-nightly==2.14.0a20230724 \ + --hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe # via -r ./requirements.in -tensorboard-data-server==0.7.0 \ - --hash=sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f \ - --hash=sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb \ - --hash=sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454 +tensorboard-data-server==0.7.1 \ + --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ + --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ + --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023060108 \ - --hash=sha256:09f63c090f29b74ebb36076c0ef1105c8b0358c6920847f312926012926ff7ce +tf-estimator-nightly==2.14.0.dev2023072408 \ + --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index a3c84ffba291ad..65c54527d47670 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -6,90 +6,90 @@ cachetools==5.3.1 \ --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \ --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b # via google-auth -certifi==2023.5.7 \ - --hash=sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7 \ - --hash=sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716 +certifi==2023.7.22 \ + --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ + --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 # via requests -charset-normalizer==3.1.0 \ - --hash=sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6 \ - 
--hash=sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1 \ - --hash=sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e \ - --hash=sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373 \ - --hash=sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62 \ - --hash=sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230 \ - --hash=sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be \ - --hash=sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c \ - --hash=sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0 \ - --hash=sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448 \ - --hash=sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f \ - --hash=sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649 \ - --hash=sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d \ - --hash=sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0 \ - --hash=sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706 \ - --hash=sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a \ - --hash=sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59 \ - --hash=sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23 \ - --hash=sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5 \ - --hash=sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb \ - --hash=sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e \ - --hash=sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e \ - --hash=sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c \ - --hash=sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28 \ - --hash=sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d \ - --hash=sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41 \ - --hash=sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974 \ - --hash=sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce \ - --hash=sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f \ - --hash=sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1 \ - --hash=sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d \ - --hash=sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8 \ - --hash=sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017 \ - --hash=sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31 \ - --hash=sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7 \ - --hash=sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8 \ - --hash=sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e \ - --hash=sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14 \ - --hash=sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd \ - --hash=sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d \ - --hash=sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795 \ - --hash=sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b \ - --hash=sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b \ - 
--hash=sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b \ - --hash=sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203 \ - --hash=sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f \ - --hash=sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19 \ - --hash=sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1 \ - --hash=sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a \ - --hash=sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac \ - --hash=sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9 \ - --hash=sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0 \ - --hash=sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137 \ - --hash=sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f \ - --hash=sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6 \ - --hash=sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5 \ - --hash=sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909 \ - --hash=sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f \ - --hash=sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0 \ - --hash=sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324 \ - --hash=sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755 \ - --hash=sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb \ - --hash=sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854 \ - --hash=sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c \ - --hash=sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60 \ - --hash=sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84 \ - --hash=sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0 \ - --hash=sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b \ - --hash=sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1 \ - --hash=sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531 \ - --hash=sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1 \ - --hash=sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11 \ - --hash=sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326 \ - --hash=sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df \ - --hash=sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab +charset-normalizer==3.2.0 \ + --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ + --hash=sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c \ + --hash=sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710 \ + --hash=sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706 \ + --hash=sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020 \ + --hash=sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252 \ + --hash=sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad \ + --hash=sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329 \ + --hash=sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a \ + --hash=sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f \ + 
--hash=sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6 \ + --hash=sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4 \ + --hash=sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a \ + --hash=sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46 \ + --hash=sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2 \ + --hash=sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23 \ + --hash=sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace \ + --hash=sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd \ + --hash=sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982 \ + --hash=sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10 \ + --hash=sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2 \ + --hash=sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea \ + --hash=sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09 \ + --hash=sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5 \ + --hash=sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149 \ + --hash=sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489 \ + --hash=sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9 \ + --hash=sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80 \ + --hash=sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592 \ + --hash=sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3 \ + --hash=sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6 \ + --hash=sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed \ + --hash=sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c \ + --hash=sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200 \ + --hash=sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a \ + --hash=sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e \ + --hash=sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d \ + --hash=sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6 \ + --hash=sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623 \ + --hash=sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669 \ + --hash=sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3 \ + --hash=sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa \ + --hash=sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9 \ + --hash=sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2 \ + --hash=sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f \ + --hash=sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1 \ + --hash=sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4 \ + --hash=sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a \ + --hash=sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8 \ + --hash=sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3 \ + --hash=sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029 \ + --hash=sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f \ + 
--hash=sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959 \ + --hash=sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22 \ + --hash=sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7 \ + --hash=sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952 \ + --hash=sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346 \ + --hash=sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e \ + --hash=sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d \ + --hash=sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299 \ + --hash=sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd \ + --hash=sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a \ + --hash=sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3 \ + --hash=sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037 \ + --hash=sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94 \ + --hash=sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c \ + --hash=sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858 \ + --hash=sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a \ + --hash=sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449 \ + --hash=sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c \ + --hash=sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918 \ + --hash=sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1 \ + --hash=sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c \ + --hash=sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac \ + --hash=sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa # via requests -google-auth==2.19.1 \ - --hash=sha256:a9cfa88b3e16196845e64a3658eb953992129d13ac7337b064c6546f77c17183 \ - --hash=sha256:ea165e014c7cbd496558796b627c271aa8c18b4cba79dc1cc962b24c5efdfb85 +google-auth==2.22.0 \ + --hash=sha256:164cba9af4e6e4e40c3a4f90a1a6c12ee56f14c0b4868d1ca91b32826ab334ce \ + --hash=sha256:d61d1b40897407b574da67da1a833bdc10d5a11642566e506565d1b1a46ba873 # via # google-auth-oauthlib # tb-nightly @@ -97,81 +97,77 @@ google-auth-oauthlib==1.0.0 \ --hash=sha256:95880ca704928c300f48194d1770cf5b1462835b6e49db61445a520f793fd5fb \ --hash=sha256:e375064964820b47221a7e1b7ee1fd77051b6323c3f9e3e19785f78ab67ecfc5 # via tb-nightly -grpcio==1.54.2 \ - --hash=sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3 \ - --hash=sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3 \ - --hash=sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821 \ - --hash=sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea \ - --hash=sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98 \ - --hash=sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5 \ - --hash=sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8 \ - --hash=sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08 \ - --hash=sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c \ - --hash=sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12 \ - --hash=sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c \ - 
--hash=sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f \ - --hash=sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e \ - --hash=sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796 \ - --hash=sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019 \ - --hash=sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474 \ - --hash=sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c \ - --hash=sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598 \ - --hash=sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7 \ - --hash=sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe \ - --hash=sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94 \ - --hash=sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c \ - --hash=sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05 \ - --hash=sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8 \ - --hash=sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf \ - --hash=sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b \ - --hash=sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771 \ - --hash=sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b \ - --hash=sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f \ - --hash=sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148 \ - --hash=sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a \ - --hash=sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610 \ - --hash=sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a \ - --hash=sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa \ - --hash=sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee \ - --hash=sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e \ - --hash=sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b \ - --hash=sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c \ - --hash=sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb \ - --hash=sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1 \ - --hash=sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4 \ - --hash=sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb \ - --hash=sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a \ - --hash=sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2 \ - --hash=sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb +grpcio==1.56.2 \ + --hash=sha256:06e84ad9ae7668a109e970c7411e7992751a116494cba7c4fb877656527f9a57 \ + --hash=sha256:0ff789ae7d8ddd76d2ac02e7d13bfef6fc4928ac01e1dcaa182be51b6bcc0aaa \ + --hash=sha256:10954662f77dc36c9a1fb5cc4a537f746580d6b5734803be1e587252682cda8d \ + --hash=sha256:139f66656a762572ae718fa0d1f2dce47c05e9fbf7a16acd704c354405b97df9 \ + --hash=sha256:1c31e52a04e62c8577a7bf772b3e7bed4df9c9e0dd90f92b6ffa07c16cab63c9 \ + --hash=sha256:33971197c47965cc1d97d78d842163c283e998223b151bab0499b951fd2c0b12 \ + --hash=sha256:345356b307cce5d14355e8e055b4ca5f99bc857c33a3dc1ddbc544fca9cd0475 \ + --hash=sha256:373b48f210f43327a41e397391715cd11cfce9ded2fe76a5068f9bacf91cc226 \ + 
--hash=sha256:3ccb621749a81dc7755243665a70ce45536ec413ef5818e013fe8dfbf5aa497b \ + --hash=sha256:42a3bbb2bc07aef72a7d97e71aabecaf3e4eb616d39e5211e2cfe3689de860ca \ + --hash=sha256:42e63904ee37ae46aa23de50dac8b145b3596f43598fa33fe1098ab2cbda6ff5 \ + --hash=sha256:4eb37dd8dd1aa40d601212afa27ca5be255ba792e2e0b24d67b8af5e012cdb7d \ + --hash=sha256:51173e8fa6d9a2d85c14426bdee5f5c4a0654fd5fddcc21fe9d09ab0f6eb8b35 \ + --hash=sha256:5144feb20fe76e73e60c7d73ec3bf54f320247d1ebe737d10672480371878b48 \ + --hash=sha256:5344be476ac37eb9c9ad09c22f4ea193c1316bf074f1daf85bddb1b31fda5116 \ + --hash=sha256:6108e5933eb8c22cd3646e72d5b54772c29f57482fd4c41a0640aab99eb5071d \ + --hash=sha256:6a007a541dff984264981fbafeb052bfe361db63578948d857907df9488d8774 \ + --hash=sha256:6ee26e9dfb3996aff7c870f09dc7ad44a5f6732b8bdb5a5f9905737ac6fd4ef1 \ + --hash=sha256:750de923b456ca8c0f1354d6befca45d1f3b3a789e76efc16741bd4132752d95 \ + --hash=sha256:7c5ede2e2558f088c49a1ddda19080e4c23fb5d171de80a726b61b567e3766ed \ + --hash=sha256:830215173ad45d670140ff99aac3b461f9be9a6b11bee1a17265aaaa746a641a \ + --hash=sha256:8391cea5ce72f4a12368afd17799474015d5d3dc00c936a907eb7c7eaaea98a5 \ + --hash=sha256:8940d6de7068af018dfa9a959a3510e9b7b543f4c405e88463a1cbaa3b2b379a \ + --hash=sha256:89a49cc5ad08a38b6141af17e00d1dd482dc927c7605bc77af457b5a0fca807c \ + --hash=sha256:900bc0096c2ca2d53f2e5cebf98293a7c32f532c4aeb926345e9747452233950 \ + --hash=sha256:97e0efaebbfd222bcaac2f1735c010c1d3b167112d9d237daebbeedaaccf3d1d \ + --hash=sha256:9e04d4e4cfafa7c5264e535b5d28e786f0571bea609c3f0aaab13e891e933e9c \ + --hash=sha256:a4c60abd950d6de3e4f1ddbc318075654d275c29c846ab6a043d6ed2c52e4c8c \ + --hash=sha256:a6ff459dac39541e6a2763a4439c4ca6bc9ecb4acc05a99b79246751f9894756 \ + --hash=sha256:a72797549935c9e0b9bc1def1768c8b5a709538fa6ab0678e671aec47ebfd55e \ + --hash=sha256:af4063ef2b11b96d949dccbc5a987272f38d55c23c4c01841ea65a517906397f \ + --hash=sha256:b975b85d1d5efc36cf8b237c5f3849b64d1ba33d6282f5e991f28751317504a1 \ + --hash=sha256:bf0b9959e673505ee5869950642428046edb91f99942607c2ecf635f8a4b31c9 \ + --hash=sha256:c0c85c5cbe8b30a32fa6d802588d55ffabf720e985abe9590c7c886919d875d4 \ + --hash=sha256:c3f3237a57e42f79f1e560726576aedb3a7ef931f4e3accb84ebf6acc485d316 \ + --hash=sha256:c3fa3ab0fb200a2c66493828ed06ccd1a94b12eddbfb985e7fd3e5723ff156c6 \ + --hash=sha256:c435f5ce1705de48e08fcbcfaf8aee660d199c90536e3e06f2016af7d6a938dd \ + --hash=sha256:c90da4b124647547a68cf2f197174ada30c7bb9523cb976665dfd26a9963d328 \ + --hash=sha256:cbdf2c498e077282cd427cfd88bdce4668019791deef0be8155385ab2ba7837f \ + --hash=sha256:d1fbad1f9077372b6587ec589c1fc120b417b6c8ad72d3e3cc86bbbd0a3cee93 \ + --hash=sha256:d39f5d4af48c138cb146763eda14eb7d8b3ccbbec9fe86fb724cd16e0e914c64 \ + --hash=sha256:ddb4a6061933bd9332b74eac0da25f17f32afa7145a33a0f9711ad74f924b1b8 \ + --hash=sha256:ded637176addc1d3eef35331c39acc598bac550d213f0a1bedabfceaa2244c87 \ + --hash=sha256:f20fd21f7538f8107451156dd1fe203300b79a9ddceba1ee0ac8132521a008ed \ + --hash=sha256:fda2783c12f553cdca11c08e5af6eecbd717280dc8fbe28a110897af1c15a88c # via # -r ./requirements.in # tb-nightly -h5py==3.8.0 \ - --hash=sha256:03890b1c123d024fb0239a3279737d5432498c1901c354f8b10d8221d1d16235 \ - --hash=sha256:0fef76e10b9216657fa37e7edff6d8be0709b25bd5066474c229b56cf0098df9 \ - --hash=sha256:26ffc344ec9984d2cd3ca0265007299a8bac8d85c1ad48f4639d8d3aed2af171 \ - --hash=sha256:290e00fa2de74a10688d1bac98d5a9cdd43f14f58e562c580b5b3dfbd358ecae \ - --hash=sha256:33b15aae79e9147aebe1d0e54099cbcde8d65e3e227cd5b59e49b1272aa0e09d \ - 
--hash=sha256:36761693efbe53df179627a775476dcbc37727d6e920958277a7efbc18f1fb73 \ - --hash=sha256:377865821fe80ad984d003723d6f8890bd54ceeb5981b43c0313b9df95411b30 \ - --hash=sha256:49bc857635f935fa30e92e61ac1e87496df8f260a6945a3235e43a9890426866 \ - --hash=sha256:4a506fc223def428f4329e7e1f9fe1c8c593eab226e7c0942c8d75308ad49950 \ - --hash=sha256:533d7dad466ddb7e3b30af274b630eb7c1a6e4ddf01d1c373a0334dc2152110a \ - --hash=sha256:5fd2252d1fc364ba0e93dd0b7089f4906b66805cb4e6aca7fa8874ac08649647 \ - --hash=sha256:6fead82f0c4000cf38d53f9c030780d81bfa0220218aee13b90b7701c937d95f \ - --hash=sha256:7f3350fc0a8407d668b13247861c2acd23f7f5fe7d060a3ad9b0820f5fcbcae0 \ - --hash=sha256:8f55d9c6c84d7d09c79fb85979e97b81ec6071cc776a97eb6b96f8f6ec767323 \ - --hash=sha256:98a240cd4c1bfd568aaa52ec42d263131a2582dab82d74d3d42a0d954cac12be \ - --hash=sha256:9f6f6ffadd6bfa9b2c5b334805eb4b19ca0a5620433659d8f7fb86692c40a359 \ - --hash=sha256:b685453e538b2b5934c58a644ac3f3b3d0cec1a01b6fb26d57388e9f9b674ad0 \ - --hash=sha256:b7865de06779b14d98068da387333ad9bf2756b5b579cc887fac169bc08f87c3 \ - --hash=sha256:bacaa1c16810dd2b3e4417f8e730971b7c4d53d234de61fe4a918db78e80e1e4 \ - --hash=sha256:bae730580ae928de409d63cbe4fdca4c82c3ad2bed30511d19d34e995d63c77e \ - --hash=sha256:c3389b63222b1c7a158bb7fe69d11ca00066740ec5574596d47a2fe5317f563a \ - --hash=sha256:c873ba9fd4fa875ad62ce0e4891725e257a8fe7f5abdbc17e51a5d54819be55c \ - --hash=sha256:db03e3f2c716205fbdabb34d0848459840585225eb97b4f08998c743821ca323 \ - --hash=sha256:f47f757d1b76f0ecb8aa0508ec8d1b390df67a8b67ee2515dc1b046f3a1596ea \ - --hash=sha256:f891b17e3a3e974e93f9e34e7cca9f530806543571ce078998676a555837d91d +h5py==3.9.0 \ + --hash=sha256:12aa556d540f11a2cae53ea7cfb94017353bd271fb3962e1296b342f6550d1b8 \ + --hash=sha256:23e74b878bbe1653ab34ca49b83cac85529cd0b36b9d625516c5830cc5ca2eac \ + --hash=sha256:36408f8c62f50007d14e000f9f3acf77e103b9e932c114cbe52a3089e50ebf94 \ + --hash=sha256:3f457089c5d524b7998e3649bc63240679b8fb0a3859ea53bbb06841f3d755f1 \ + --hash=sha256:54f01202cdea754ab4227dd27014bdbd561a4bbe4b631424fd812f7c2ce9c6ac \ + --hash=sha256:551e358db05a874a0f827b22e95b30092f2303edc4b91bb62ad2f10e0236e1a0 \ + --hash=sha256:64acceaf6aff92af091a4b83f6dee3cf8d3061f924a6bb3a33eb6c4658a8348b \ + --hash=sha256:6822a814b9d8b8363ff102f76ea8d026f0ca25850bb579d85376029ee3e73b93 \ + --hash=sha256:78e44686334cbbf2dd21d9df15823bc38663f27a3061f6a032c68a3e30c47bf7 \ + --hash=sha256:79bbca34696c6f9eeeb36a91776070c49a060b2879828e2c8fa6c58b8ed10dd1 \ + --hash=sha256:804c7fb42a34c8ab3a3001901c977a5c24d2e9c586a0f3e7c0a389130b4276fc \ + --hash=sha256:8d9492391ff5c3c80ec30ae2fe82a3f0efd1e750833739c25b0d090e3be1b095 \ + --hash=sha256:95f7a745efd0d56076999b52e8da5fad5d30823bac98b59c68ae75588d09991a \ + --hash=sha256:9da9e7e63376c32704e37ad4cea2dceae6964cee0d8515185b3ab9cbd6b947bc \ + --hash=sha256:a4e20897c88759cbcbd38fb45b507adc91af3e0f67722aa302d71f02dd44d286 \ + --hash=sha256:a6284061f3214335e1eec883a6ee497dbe7a79f19e6a57fed2dd1f03acd5a8cb \ + --hash=sha256:d97409e17915798029e297a84124705c8080da901307ea58f29234e09b073ddc \ + --hash=sha256:dbf5225543ca35ce9f61c950b73899a82be7ba60d58340e76d0bd42bf659235a \ + --hash=sha256:e604db6521c1e367c6bd7fad239c847f53cc46646f2d2651372d05ae5e95f817 \ + --hash=sha256:eb7bdd5e601dd1739698af383be03f3dad0465fe67184ebd5afca770f50df9d6 \ + --hash=sha256:f68b41efd110ce9af1cbe6fa8af9f4dcbadace6db972d30828b911949e28fadd # via -r ./requirements.in idna==3.4 \ --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ @@ -180,12 +176,12 @@ 
idna==3.4 \ jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023061207 \ - --hash=sha256:210671f010a0b21a5507be86b8e9e909f81b9f321cd3c51e1efdfdd41061919f \ - --hash=sha256:ad869b2bce863e111e4a57c7f5785f56097d93f683b5315df7f59917be1fa279 +keras-nightly==2.14.0.dev2023072407 \ + --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ + --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 # via -r ./requirements.in -lit==16.0.5.post0 \ - --hash=sha256:71745d9e58dad3717735d27e2a9cca0e9ca6861d067da73c307e02fd38c98479 +lit==16.0.6 \ + --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a # via -r ./requirements.in markdown==3.4.3 \ --hash=sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2 \ @@ -315,20 +311,20 @@ portpicker==1.5.2 \ --hash=sha256:01113f51c3cc63290a44dd7ae6e3eb9f8fe1b8a1f9d7988a897944230c39cd52 \ --hash=sha256:c55683ad725f5c00a41bc7db0225223e8be024b1fa564d039ed3390e4fd48fb3 # via -r ./requirements.in -protobuf==4.23.2 \ - --hash=sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a \ - --hash=sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7 \ - --hash=sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511 \ - --hash=sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f \ - --hash=sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3 \ - --hash=sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f \ - --hash=sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e \ - --hash=sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa \ - --hash=sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f \ - --hash=sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df \ - --hash=sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e \ - --hash=sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea \ - --hash=sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91 +protobuf==4.23.4 \ + --hash=sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474 \ + --hash=sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2 \ + --hash=sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b \ + --hash=sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720 \ + --hash=sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12 \ + --hash=sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd \ + --hash=sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0 \ + --hash=sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e \ + --hash=sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9 \ + --hash=sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70 \ + --hash=sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff \ + --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ + --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly psutil==5.9.5 \ --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \ @@ -406,16 +402,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ 
--hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230612 \ - --hash=sha256:1ad7d57386f6103df69d8f3083552d185d33c5a07fb18e691b60236d9b4dc679 +tb-nightly==2.14.0a20230724 \ + --hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe # via -r ./requirements.in -tensorboard-data-server==0.7.0 \ - --hash=sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f \ - --hash=sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb \ - --hash=sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454 +tensorboard-data-server==0.7.1 \ + --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ + --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ + --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023060108 \ - --hash=sha256:09f63c090f29b74ebb36076c0ef1105c8b0358c6920847f312926012926ff7ce +tf-estimator-nightly==2.14.0.dev2023072408 \ + --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 52dc137e39f429..0bda44766a13e9 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -6,90 +6,90 @@ cachetools==5.3.1 \ --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \ --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b # via google-auth -certifi==2023.5.7 \ - --hash=sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7 \ - --hash=sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716 +certifi==2023.7.22 \ + --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ + --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 # via requests -charset-normalizer==3.1.0 \ - --hash=sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6 \ - --hash=sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1 \ - --hash=sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e \ - --hash=sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373 \ - --hash=sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62 \ - --hash=sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230 \ - --hash=sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be \ - --hash=sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c \ - --hash=sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0 \ - --hash=sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448 \ - --hash=sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f \ - --hash=sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649 \ - --hash=sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d \ - --hash=sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0 \ - --hash=sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706 \ - --hash=sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a \ - --hash=sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59 \ - 
--hash=sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23 \ - --hash=sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5 \ - --hash=sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb \ - --hash=sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e \ - --hash=sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e \ - --hash=sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c \ - --hash=sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28 \ - --hash=sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d \ - --hash=sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41 \ - --hash=sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974 \ - --hash=sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce \ - --hash=sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f \ - --hash=sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1 \ - --hash=sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d \ - --hash=sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8 \ - --hash=sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017 \ - --hash=sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31 \ - --hash=sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7 \ - --hash=sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8 \ - --hash=sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e \ - --hash=sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14 \ - --hash=sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd \ - --hash=sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d \ - --hash=sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795 \ - --hash=sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b \ - --hash=sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b \ - --hash=sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b \ - --hash=sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203 \ - --hash=sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f \ - --hash=sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19 \ - --hash=sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1 \ - --hash=sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a \ - --hash=sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac \ - --hash=sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9 \ - --hash=sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0 \ - --hash=sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137 \ - --hash=sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f \ - --hash=sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6 \ - --hash=sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5 \ - --hash=sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909 \ - --hash=sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f \ - --hash=sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0 \ - 
--hash=sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324 \ - --hash=sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755 \ - --hash=sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb \ - --hash=sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854 \ - --hash=sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c \ - --hash=sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60 \ - --hash=sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84 \ - --hash=sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0 \ - --hash=sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b \ - --hash=sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1 \ - --hash=sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531 \ - --hash=sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1 \ - --hash=sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11 \ - --hash=sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326 \ - --hash=sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df \ - --hash=sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab +charset-normalizer==3.2.0 \ + --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ + --hash=sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c \ + --hash=sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710 \ + --hash=sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706 \ + --hash=sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020 \ + --hash=sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252 \ + --hash=sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad \ + --hash=sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329 \ + --hash=sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a \ + --hash=sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f \ + --hash=sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6 \ + --hash=sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4 \ + --hash=sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a \ + --hash=sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46 \ + --hash=sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2 \ + --hash=sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23 \ + --hash=sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace \ + --hash=sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd \ + --hash=sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982 \ + --hash=sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10 \ + --hash=sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2 \ + --hash=sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea \ + --hash=sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09 \ + --hash=sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5 \ + --hash=sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149 \ + --hash=sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489 \ + 
--hash=sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9 \ + --hash=sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80 \ + --hash=sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592 \ + --hash=sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3 \ + --hash=sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6 \ + --hash=sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed \ + --hash=sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c \ + --hash=sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200 \ + --hash=sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a \ + --hash=sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e \ + --hash=sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d \ + --hash=sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6 \ + --hash=sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623 \ + --hash=sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669 \ + --hash=sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3 \ + --hash=sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa \ + --hash=sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9 \ + --hash=sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2 \ + --hash=sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f \ + --hash=sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1 \ + --hash=sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4 \ + --hash=sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a \ + --hash=sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8 \ + --hash=sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3 \ + --hash=sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029 \ + --hash=sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f \ + --hash=sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959 \ + --hash=sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22 \ + --hash=sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7 \ + --hash=sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952 \ + --hash=sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346 \ + --hash=sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e \ + --hash=sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d \ + --hash=sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299 \ + --hash=sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd \ + --hash=sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a \ + --hash=sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3 \ + --hash=sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037 \ + --hash=sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94 \ + --hash=sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c \ + --hash=sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858 \ + --hash=sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a \ + 
--hash=sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449 \ + --hash=sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c \ + --hash=sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918 \ + --hash=sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1 \ + --hash=sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c \ + --hash=sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac \ + --hash=sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa # via requests -google-auth==2.19.1 \ - --hash=sha256:a9cfa88b3e16196845e64a3658eb953992129d13ac7337b064c6546f77c17183 \ - --hash=sha256:ea165e014c7cbd496558796b627c271aa8c18b4cba79dc1cc962b24c5efdfb85 +google-auth==2.22.0 \ + --hash=sha256:164cba9af4e6e4e40c3a4f90a1a6c12ee56f14c0b4868d1ca91b32826ab334ce \ + --hash=sha256:d61d1b40897407b574da67da1a833bdc10d5a11642566e506565d1b1a46ba873 # via # google-auth-oauthlib # tb-nightly @@ -97,99 +97,95 @@ google-auth-oauthlib==1.0.0 \ --hash=sha256:95880ca704928c300f48194d1770cf5b1462835b6e49db61445a520f793fd5fb \ --hash=sha256:e375064964820b47221a7e1b7ee1fd77051b6323c3f9e3e19785f78ab67ecfc5 # via tb-nightly -grpcio==1.54.2 \ - --hash=sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3 \ - --hash=sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3 \ - --hash=sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821 \ - --hash=sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea \ - --hash=sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98 \ - --hash=sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5 \ - --hash=sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8 \ - --hash=sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08 \ - --hash=sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c \ - --hash=sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12 \ - --hash=sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c \ - --hash=sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f \ - --hash=sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e \ - --hash=sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796 \ - --hash=sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019 \ - --hash=sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474 \ - --hash=sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c \ - --hash=sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598 \ - --hash=sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7 \ - --hash=sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe \ - --hash=sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94 \ - --hash=sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c \ - --hash=sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05 \ - --hash=sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8 \ - --hash=sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf \ - --hash=sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b \ - --hash=sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771 \ - 
--hash=sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b \ - --hash=sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f \ - --hash=sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148 \ - --hash=sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a \ - --hash=sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610 \ - --hash=sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a \ - --hash=sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa \ - --hash=sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee \ - --hash=sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e \ - --hash=sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b \ - --hash=sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c \ - --hash=sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb \ - --hash=sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1 \ - --hash=sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4 \ - --hash=sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb \ - --hash=sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a \ - --hash=sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2 \ - --hash=sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb +grpcio==1.56.2 \ + --hash=sha256:06e84ad9ae7668a109e970c7411e7992751a116494cba7c4fb877656527f9a57 \ + --hash=sha256:0ff789ae7d8ddd76d2ac02e7d13bfef6fc4928ac01e1dcaa182be51b6bcc0aaa \ + --hash=sha256:10954662f77dc36c9a1fb5cc4a537f746580d6b5734803be1e587252682cda8d \ + --hash=sha256:139f66656a762572ae718fa0d1f2dce47c05e9fbf7a16acd704c354405b97df9 \ + --hash=sha256:1c31e52a04e62c8577a7bf772b3e7bed4df9c9e0dd90f92b6ffa07c16cab63c9 \ + --hash=sha256:33971197c47965cc1d97d78d842163c283e998223b151bab0499b951fd2c0b12 \ + --hash=sha256:345356b307cce5d14355e8e055b4ca5f99bc857c33a3dc1ddbc544fca9cd0475 \ + --hash=sha256:373b48f210f43327a41e397391715cd11cfce9ded2fe76a5068f9bacf91cc226 \ + --hash=sha256:3ccb621749a81dc7755243665a70ce45536ec413ef5818e013fe8dfbf5aa497b \ + --hash=sha256:42a3bbb2bc07aef72a7d97e71aabecaf3e4eb616d39e5211e2cfe3689de860ca \ + --hash=sha256:42e63904ee37ae46aa23de50dac8b145b3596f43598fa33fe1098ab2cbda6ff5 \ + --hash=sha256:4eb37dd8dd1aa40d601212afa27ca5be255ba792e2e0b24d67b8af5e012cdb7d \ + --hash=sha256:51173e8fa6d9a2d85c14426bdee5f5c4a0654fd5fddcc21fe9d09ab0f6eb8b35 \ + --hash=sha256:5144feb20fe76e73e60c7d73ec3bf54f320247d1ebe737d10672480371878b48 \ + --hash=sha256:5344be476ac37eb9c9ad09c22f4ea193c1316bf074f1daf85bddb1b31fda5116 \ + --hash=sha256:6108e5933eb8c22cd3646e72d5b54772c29f57482fd4c41a0640aab99eb5071d \ + --hash=sha256:6a007a541dff984264981fbafeb052bfe361db63578948d857907df9488d8774 \ + --hash=sha256:6ee26e9dfb3996aff7c870f09dc7ad44a5f6732b8bdb5a5f9905737ac6fd4ef1 \ + --hash=sha256:750de923b456ca8c0f1354d6befca45d1f3b3a789e76efc16741bd4132752d95 \ + --hash=sha256:7c5ede2e2558f088c49a1ddda19080e4c23fb5d171de80a726b61b567e3766ed \ + --hash=sha256:830215173ad45d670140ff99aac3b461f9be9a6b11bee1a17265aaaa746a641a \ + --hash=sha256:8391cea5ce72f4a12368afd17799474015d5d3dc00c936a907eb7c7eaaea98a5 \ + --hash=sha256:8940d6de7068af018dfa9a959a3510e9b7b543f4c405e88463a1cbaa3b2b379a \ + --hash=sha256:89a49cc5ad08a38b6141af17e00d1dd482dc927c7605bc77af457b5a0fca807c \ + 
--hash=sha256:900bc0096c2ca2d53f2e5cebf98293a7c32f532c4aeb926345e9747452233950 \ + --hash=sha256:97e0efaebbfd222bcaac2f1735c010c1d3b167112d9d237daebbeedaaccf3d1d \ + --hash=sha256:9e04d4e4cfafa7c5264e535b5d28e786f0571bea609c3f0aaab13e891e933e9c \ + --hash=sha256:a4c60abd950d6de3e4f1ddbc318075654d275c29c846ab6a043d6ed2c52e4c8c \ + --hash=sha256:a6ff459dac39541e6a2763a4439c4ca6bc9ecb4acc05a99b79246751f9894756 \ + --hash=sha256:a72797549935c9e0b9bc1def1768c8b5a709538fa6ab0678e671aec47ebfd55e \ + --hash=sha256:af4063ef2b11b96d949dccbc5a987272f38d55c23c4c01841ea65a517906397f \ + --hash=sha256:b975b85d1d5efc36cf8b237c5f3849b64d1ba33d6282f5e991f28751317504a1 \ + --hash=sha256:bf0b9959e673505ee5869950642428046edb91f99942607c2ecf635f8a4b31c9 \ + --hash=sha256:c0c85c5cbe8b30a32fa6d802588d55ffabf720e985abe9590c7c886919d875d4 \ + --hash=sha256:c3f3237a57e42f79f1e560726576aedb3a7ef931f4e3accb84ebf6acc485d316 \ + --hash=sha256:c3fa3ab0fb200a2c66493828ed06ccd1a94b12eddbfb985e7fd3e5723ff156c6 \ + --hash=sha256:c435f5ce1705de48e08fcbcfaf8aee660d199c90536e3e06f2016af7d6a938dd \ + --hash=sha256:c90da4b124647547a68cf2f197174ada30c7bb9523cb976665dfd26a9963d328 \ + --hash=sha256:cbdf2c498e077282cd427cfd88bdce4668019791deef0be8155385ab2ba7837f \ + --hash=sha256:d1fbad1f9077372b6587ec589c1fc120b417b6c8ad72d3e3cc86bbbd0a3cee93 \ + --hash=sha256:d39f5d4af48c138cb146763eda14eb7d8b3ccbbec9fe86fb724cd16e0e914c64 \ + --hash=sha256:ddb4a6061933bd9332b74eac0da25f17f32afa7145a33a0f9711ad74f924b1b8 \ + --hash=sha256:ded637176addc1d3eef35331c39acc598bac550d213f0a1bedabfceaa2244c87 \ + --hash=sha256:f20fd21f7538f8107451156dd1fe203300b79a9ddceba1ee0ac8132521a008ed \ + --hash=sha256:fda2783c12f553cdca11c08e5af6eecbd717280dc8fbe28a110897af1c15a88c # via # -r ./requirements.in # tb-nightly -h5py==3.8.0 \ - --hash=sha256:03890b1c123d024fb0239a3279737d5432498c1901c354f8b10d8221d1d16235 \ - --hash=sha256:0fef76e10b9216657fa37e7edff6d8be0709b25bd5066474c229b56cf0098df9 \ - --hash=sha256:26ffc344ec9984d2cd3ca0265007299a8bac8d85c1ad48f4639d8d3aed2af171 \ - --hash=sha256:290e00fa2de74a10688d1bac98d5a9cdd43f14f58e562c580b5b3dfbd358ecae \ - --hash=sha256:33b15aae79e9147aebe1d0e54099cbcde8d65e3e227cd5b59e49b1272aa0e09d \ - --hash=sha256:36761693efbe53df179627a775476dcbc37727d6e920958277a7efbc18f1fb73 \ - --hash=sha256:377865821fe80ad984d003723d6f8890bd54ceeb5981b43c0313b9df95411b30 \ - --hash=sha256:49bc857635f935fa30e92e61ac1e87496df8f260a6945a3235e43a9890426866 \ - --hash=sha256:4a506fc223def428f4329e7e1f9fe1c8c593eab226e7c0942c8d75308ad49950 \ - --hash=sha256:533d7dad466ddb7e3b30af274b630eb7c1a6e4ddf01d1c373a0334dc2152110a \ - --hash=sha256:5fd2252d1fc364ba0e93dd0b7089f4906b66805cb4e6aca7fa8874ac08649647 \ - --hash=sha256:6fead82f0c4000cf38d53f9c030780d81bfa0220218aee13b90b7701c937d95f \ - --hash=sha256:7f3350fc0a8407d668b13247861c2acd23f7f5fe7d060a3ad9b0820f5fcbcae0 \ - --hash=sha256:8f55d9c6c84d7d09c79fb85979e97b81ec6071cc776a97eb6b96f8f6ec767323 \ - --hash=sha256:98a240cd4c1bfd568aaa52ec42d263131a2582dab82d74d3d42a0d954cac12be \ - --hash=sha256:9f6f6ffadd6bfa9b2c5b334805eb4b19ca0a5620433659d8f7fb86692c40a359 \ - --hash=sha256:b685453e538b2b5934c58a644ac3f3b3d0cec1a01b6fb26d57388e9f9b674ad0 \ - --hash=sha256:b7865de06779b14d98068da387333ad9bf2756b5b579cc887fac169bc08f87c3 \ - --hash=sha256:bacaa1c16810dd2b3e4417f8e730971b7c4d53d234de61fe4a918db78e80e1e4 \ - --hash=sha256:bae730580ae928de409d63cbe4fdca4c82c3ad2bed30511d19d34e995d63c77e \ - --hash=sha256:c3389b63222b1c7a158bb7fe69d11ca00066740ec5574596d47a2fe5317f563a \ - 
--hash=sha256:c873ba9fd4fa875ad62ce0e4891725e257a8fe7f5abdbc17e51a5d54819be55c \ - --hash=sha256:db03e3f2c716205fbdabb34d0848459840585225eb97b4f08998c743821ca323 \ - --hash=sha256:f47f757d1b76f0ecb8aa0508ec8d1b390df67a8b67ee2515dc1b046f3a1596ea \ - --hash=sha256:f891b17e3a3e974e93f9e34e7cca9f530806543571ce078998676a555837d91d +h5py==3.9.0 \ + --hash=sha256:12aa556d540f11a2cae53ea7cfb94017353bd271fb3962e1296b342f6550d1b8 \ + --hash=sha256:23e74b878bbe1653ab34ca49b83cac85529cd0b36b9d625516c5830cc5ca2eac \ + --hash=sha256:36408f8c62f50007d14e000f9f3acf77e103b9e932c114cbe52a3089e50ebf94 \ + --hash=sha256:3f457089c5d524b7998e3649bc63240679b8fb0a3859ea53bbb06841f3d755f1 \ + --hash=sha256:54f01202cdea754ab4227dd27014bdbd561a4bbe4b631424fd812f7c2ce9c6ac \ + --hash=sha256:551e358db05a874a0f827b22e95b30092f2303edc4b91bb62ad2f10e0236e1a0 \ + --hash=sha256:64acceaf6aff92af091a4b83f6dee3cf8d3061f924a6bb3a33eb6c4658a8348b \ + --hash=sha256:6822a814b9d8b8363ff102f76ea8d026f0ca25850bb579d85376029ee3e73b93 \ + --hash=sha256:78e44686334cbbf2dd21d9df15823bc38663f27a3061f6a032c68a3e30c47bf7 \ + --hash=sha256:79bbca34696c6f9eeeb36a91776070c49a060b2879828e2c8fa6c58b8ed10dd1 \ + --hash=sha256:804c7fb42a34c8ab3a3001901c977a5c24d2e9c586a0f3e7c0a389130b4276fc \ + --hash=sha256:8d9492391ff5c3c80ec30ae2fe82a3f0efd1e750833739c25b0d090e3be1b095 \ + --hash=sha256:95f7a745efd0d56076999b52e8da5fad5d30823bac98b59c68ae75588d09991a \ + --hash=sha256:9da9e7e63376c32704e37ad4cea2dceae6964cee0d8515185b3ab9cbd6b947bc \ + --hash=sha256:a4e20897c88759cbcbd38fb45b507adc91af3e0f67722aa302d71f02dd44d286 \ + --hash=sha256:a6284061f3214335e1eec883a6ee497dbe7a79f19e6a57fed2dd1f03acd5a8cb \ + --hash=sha256:d97409e17915798029e297a84124705c8080da901307ea58f29234e09b073ddc \ + --hash=sha256:dbf5225543ca35ce9f61c950b73899a82be7ba60d58340e76d0bd42bf659235a \ + --hash=sha256:e604db6521c1e367c6bd7fad239c847f53cc46646f2d2651372d05ae5e95f817 \ + --hash=sha256:eb7bdd5e601dd1739698af383be03f3dad0465fe67184ebd5afca770f50df9d6 \ + --hash=sha256:f68b41efd110ce9af1cbe6fa8af9f4dcbadace6db972d30828b911949e28fadd # via -r ./requirements.in idna==3.4 \ --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 # via requests -importlib-metadata==6.6.0 \ - --hash=sha256:43dd286a2cd8995d5eaef7fee2066340423b818ed3fd70adf0bad5f1fac53fed \ - --hash=sha256:92501cdf9cc66ebd3e612f1b4f0c0765dfa42f0fa38ffb319b6bd84dd675d705 +importlib-metadata==6.8.0 \ + --hash=sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb \ + --hash=sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743 # via markdown jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023061207 \ - --hash=sha256:210671f010a0b21a5507be86b8e9e909f81b9f321cd3c51e1efdfdd41061919f \ - --hash=sha256:ad869b2bce863e111e4a57c7f5785f56097d93f683b5315df7f59917be1fa279 +keras-nightly==2.14.0.dev2023072407 \ + --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ + --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 # via -r ./requirements.in -lit==16.0.5.post0 \ - --hash=sha256:71745d9e58dad3717735d27e2a9cca0e9ca6861d067da73c307e02fd38c98479 +lit==16.0.6 \ + --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a # via -r ./requirements.in markdown==3.4.3 \ 
--hash=sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2 \ @@ -319,20 +315,20 @@ portpicker==1.5.2 \ --hash=sha256:01113f51c3cc63290a44dd7ae6e3eb9f8fe1b8a1f9d7988a897944230c39cd52 \ --hash=sha256:c55683ad725f5c00a41bc7db0225223e8be024b1fa564d039ed3390e4fd48fb3 # via -r ./requirements.in -protobuf==4.23.2 \ - --hash=sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a \ - --hash=sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7 \ - --hash=sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511 \ - --hash=sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f \ - --hash=sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3 \ - --hash=sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f \ - --hash=sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e \ - --hash=sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa \ - --hash=sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f \ - --hash=sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df \ - --hash=sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e \ - --hash=sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea \ - --hash=sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91 +protobuf==4.23.4 \ + --hash=sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474 \ + --hash=sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2 \ + --hash=sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b \ + --hash=sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720 \ + --hash=sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12 \ + --hash=sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd \ + --hash=sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0 \ + --hash=sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e \ + --hash=sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9 \ + --hash=sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70 \ + --hash=sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff \ + --hash=sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597 \ + --hash=sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a # via tb-nightly psutil==5.9.5 \ --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \ @@ -410,16 +406,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230612 \ - --hash=sha256:1ad7d57386f6103df69d8f3083552d185d33c5a07fb18e691b60236d9b4dc679 +tb-nightly==2.14.0a20230724 \ + --hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe # via -r ./requirements.in -tensorboard-data-server==0.7.0 \ - --hash=sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f \ - --hash=sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb \ - --hash=sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454 +tensorboard-data-server==0.7.1 \ + --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ + 
--hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ + --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023060108 \ - --hash=sha256:09f63c090f29b74ebb36076c0ef1105c8b0358c6920847f312926012926ff7ce +tf-estimator-nightly==2.14.0.dev2023072408 \ + --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ @@ -437,7 +433,7 @@ wheel==0.38.4 \ # via # -r ./requirements.in # tb-nightly -zipp==3.15.0 \ - --hash=sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b \ - --hash=sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556 +zipp==3.16.2 \ + --hash=sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0 \ + --hash=sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147 # via importlib-metadata From ae4502aea67e6ef317643d13b8c42c5005151411 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 11:20:37 -0700 Subject: [PATCH 123/410] Extract transpose fusions from ir_emitter_unnested. PiperOrigin-RevId: 550943790 --- .../compiler/xla/service/gpu/fusions/BUILD | 23 -- .../xla/service/gpu/fusions/fusions.cc | 4 - .../xla/service/gpu/fusions/transpose.cc | 232 ------------------ .../xla/service/gpu/fusions/transpose.h | 81 ------ .../xla/service/gpu/ir_emitter_unnested.cc | 222 ++++++++++++++++- .../xla/service/gpu/ir_emitter_unnested.h | 46 ++++ 6 files changed, 267 insertions(+), 341 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/transpose.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/transpose.h diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 5bd8adde24b674..644544246200ce 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -62,7 +62,6 @@ cc_library( ":in_place_dynamic_update_slice", ":loop", ":reduction", - ":transpose", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", @@ -148,25 +147,3 @@ cc_library( "@llvm-project//mlir:MemRefDialect", ], ) - -cc_library( - name = "transpose", - srcs = ["transpose.cc"], - hdrs = ["transpose.h"], - deps = [ - ":fusion_emitter", - ":tiling_util", - "//tensorflow/compiler/xla:permutation_util", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/compiler/xla/service:elemental_ir_emitter", - "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", - "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", - "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", - "//tensorflow/compiler/xla/service/gpu:target_util", - "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ir_array", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm-project//llvm:ir_headers", - ], -) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index cbbd60050eda7f..34723c1ec64eef 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" #include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -54,9 +53,6 @@ std::optional> GetFusionEmitter( case HloFusionAnalysis::EmitterFusionKind::kReduction: return std::make_unique( ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); - case HloFusionAnalysis::EmitterFusionKind::kTranspose: - return std::make_unique( - ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { bool is_single = IsSingleInstructionFusion(fusion_op); if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( diff --git a/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc b/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc deleted file mode 100644 index c56323fc4aca78..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/transpose.cc +++ /dev/null @@ -1,232 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/transpose.h" - -#include - -#include "llvm/IR/IRBuilder.h" -#include "tensorflow/compiler/xla/permutation_util.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" -#include "tensorflow/compiler/xla/service/gpu/target_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" - -namespace xla { -namespace gpu { -namespace { - -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); -} - -void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - auto* module = builder->GetInsertBlock()->getModule(); - if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { - builder->CreateFence( - llvm::AtomicOrdering::SequentiallyConsistent, - builder->getContext().getOrInsertSyncScopeID("workgroup")); - } -} - -void EmitSyncThreads(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); -} - -llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index, - absl::Span permutation) { - return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation), - Permute(index.dims(), permutation), - index.GetType()}; -} - -} // namespace - -Status TransposeFusion::EmitKernel(const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const { - const auto& tiling_scheme = *analysis_.GetTransposeTilingScheme(); - std::vector hlo_roots = - GetFusionRoots(fusion().fused_instructions_computation()); - FusedIrEmitter fused_emitter(elemental_emitter()); - for (auto [i, input] : llvm::enumerate(inputs)) { - HloInstruction* fused_operand = fusion().fused_parameter(i); - fused_emitter.BindGenerator( - *fused_operand, [input = input, builder, - fused_operand](const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - absl::flat_hash_map tiles; - Vector3 permutation; - for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { - if (auto tr = FindAnyTiledTranspose(*root)) { - permutation = tr->permutation; - const HloInstruction& hero = FindNonTrivialHero(*root); - tiles[&hero] = AllocateShared( - builder, tiling_scheme, - llvm_ir::PrimitiveTypeToIrType( - hero.operand(0)->shape().element_type(), - ir_emitter_context().llvm_module()), - {tiling_scheme.GetBlockTileSizeFor(permutation[TilingScheme::DimX]), - tiling_scheme.GetBlockTileSizeFor(TilingScheme::DimX) + 1}, - absl::StrCat("tr_tile_", tile_idx)); - } - } - - TileElementGenerator tile_generator = - [&](const TilingThreadIdInfo& thread_id_info, - const 
llvm_ir::IrArray::Index& index, - std::array tile_dimensions) { - // Copy input parameter values to shared memory buffers: - // tile[thread_id_y, thread_id_x] = input[index] - // Note that tile_width and tile_height are flipped here because we - // are reading a transposed tile. - EmitTile( - builder, tiling_scheme, index, thread_id_info, tile_dimensions, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - // Compute all extra output values before writing them. This - // avoids overwriting aliased input/output values before all reads - // occurred. - std::vector> - scheduled_writes; - - for (const auto& [output_idx, root] : - llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*hero.operand(0)); - llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( - index, hero.operand(0)->shape(), builder, - tiling_scheme.GetDimsInElems()); - llvm::Value* value = *input_gen(untiled_index); - llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( - builder, tiles[&hero], {y_loc, x_loc}); - - builder->CreateStore(value, addr); - } else { - llvm_ir::IrArray::Index untiled_index = - GetUnnormalizedIndex(index, root->shape(), builder, - tiling_scheme.GetDimsInElems()); - llvm_ir::ElementGenerator output_gen = - *fused_emitter.GetGenerator(*root); - llvm::Value* output_value = *output_gen(untiled_index); - scheduled_writes.emplace_back(outputs[output_idx], - untiled_index, output_value); - } - } - - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, builder); - } - }); - - EmitSyncThreads(builder, ir_emitter_context()); - - llvm_ir::IrArray::Index output_tile_index = - PermuteIndex(index, permutation); - std::array transposed_tile_dimensions = { - tile_dimensions[1], tile_dimensions[0]}; - - EmitTile( - builder, tiling_scheme, output_tile_index, thread_id_info, - transposed_tile_dimensions, - /*emit_elem_function=*/ - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - for (const auto& [output_idx, root] : - llvm::enumerate(hlo_roots)) { - if (FindAnyTiledTranspose(*root)) { - const HloInstruction& hero = FindNonTrivialHero(*root); - - std::vector idx = {x_loc, y_loc}; - llvm::Value* gep = thread_id_info.GEPIntoSharedMemory( - builder, tiles[&hero], idx); - llvm::Type* type = - thread_id_info.GEPIntoSharedMemoryType(tiles[&hero], idx); - llvm::Value* loaded = - builder->CreateLoad(type, gep, "tiled_buffer"); - - FusedIrEmitter fused_emitter(elemental_emitter()); - fused_emitter.BindGenerator( - hero, [&](const llvm_ir::IrArray::Index& index) { - return loaded; - }); - for (int64_t i = 0; i < fusion() - .fused_instructions_computation() - ->num_parameters(); - ++i) { - llvm_ir::IrArray ir_array = inputs[i]; - HloInstruction* fused_operand = fusion().fused_parameter(i); - fused_emitter.BindGenerator( - *fused_operand, - [=](const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement( - index, builder, fused_operand->name()); - }); - } - - // Apply codegeneration for the code after the real hero. - TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, - fused_emitter.GetGenerator(*root)); - - // Both for emission and writing it should be - // index-as-transformed by the computation. 
- llvm_ir::IrArray::Index untiled_index = GetUnnormalizedIndex( - index, root->shape(), builder, - Permute(tiling_scheme.GetDimsInElems(), permutation)); - TF_ASSIGN_OR_RETURN(llvm::Value * generated, - gen(untiled_index)); - outputs[output_idx].EmitWriteArrayElement(untiled_index, - generated, builder); - } - } - return OkStatus(); - }); - }; - - llvm::Type* index_type = - GetIndexTypeForKernel(fusion_op(), launch_dims.launch_bound(), builder); - return EmitTilingKernel(builder, tiling_scheme, index_type, tile_generator) - .status(); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/transpose.h b/tensorflow/compiler/xla/service/gpu/fusions/transpose.h deleted file mode 100644 index cfd1ff9eef96c0..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/transpose.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ - -#include - -#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" -#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" - -namespace xla { -namespace gpu { - -// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose -// algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. The -// caller is responsible for making sure that it is safe to apply the shared -// memory transpose on the input parameters. -// -// For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape -// of three components 0-1-2 in the order major to minor. The x- and y- -// dimensions of the tensors are tiled in square tiles with an edge length -// `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads -// transposes one tile: each thread copies kTileSize/kNumRows elements from -// the input to a shared memory tile, then the otherwise "regular HLO kernel" -// reads from the shared memory instead of the original input. -// -// This is similar to the following CUDA algorithm in TensorFlow: -// https://goo.gl/MStRV6. -// -// `kTileSize` should usually be same as warp size. We currently choose 32 for -// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more -// efficient to launch fewer blocks so each transposes many tiles. 
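The doc comment above (carried over verbatim onto ir_emitter_unnested.h later in this patch) describes the classic shared-memory tiled transpose. As a rough orientation, a minimal CUDA-style sketch under the stated choice of `kTileSize` = 32 and `kNumRows` = 4 could look as follows; the kernel name, the rows/cols parameters and the row-major layout are illustrative assumptions only, since the emitter produces the equivalent LLVM IR directly rather than CUDA source:

template <typename T>
__global__ void TransposeTiledSketch(const T* in, T* out, int rows, int cols) {
  constexpr int kTileSize = 32;  // usually the warp size
  constexpr int kNumRows = 4;    // thread block is kTileSize x kNumRows threads
  // +1 on the minor dimension staggers rows across shared-memory banks.
  __shared__ T tile[kTileSize][kTileSize + 1];

  int x = blockIdx.x * kTileSize + threadIdx.x;
  int y = blockIdx.y * kTileSize + threadIdx.y;
  // Each thread copies kTileSize / kNumRows = 8 elements into the tile.
  for (int j = 0; j < kTileSize; j += kNumRows) {
    if (x < cols && y + j < rows) {
      tile[threadIdx.y + j][threadIdx.x] = in[(y + j) * cols + x];
    }
  }
  __syncthreads();

  // Write the tile back with the block coordinates swapped, so both the
  // global read and the global write stay coalesced.
  x = blockIdx.y * kTileSize + threadIdx.x;
  y = blockIdx.x * kTileSize + threadIdx.y;
  for (int j = 0; j < kTileSize; j += kNumRows) {
    if (x < rows && y + j < cols) {
      out[(y + j) * rows + x] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}

The `kTileSize + 1` padding in the sketch mirrors the `GetBlockTileSizeFor(kDimX) + 1` shared-tile allocation in the moved EmitTransposeTile code, and the `__syncthreads()` plays the role of the EmitSyncThreads() barrier emitted between the two EmitTile calls.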
-class TransposeFusion : public KernelFusionEmitterBase { - public: - TransposeFusion(IrEmitterContext& ir_emitter_context, - ElementalIrEmitter& elemental_emitter, - mlir::lmhlo::FusionOp fusion_op, - const HloFusionInstruction& fusion, - HloFusionAnalysis& analysis) - : KernelFusionEmitterBase(ir_emitter_context, elemental_emitter, - fusion_op, fusion), - analysis_(analysis) {} - StatusOr launch_dimensions() const override { - return analysis_.GetLaunchDimensions(false); - } - - protected: - Status EmitKernel(const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder, - int kernel_index) const override; - - private: - HloFusionAnalysis& analysis_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index e220091a188ddd..3e23e291c96330 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -68,6 +69,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" @@ -77,6 +79,7 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/transforms/gpu_passes.h" +#include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -90,6 +93,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fused_mha_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fusions/fusions.h" #include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" +#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" @@ -107,6 +111,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h" @@ -1941,6 +1946,32 @@ Status IrEmitterUnnested::EmitTritonFusion( #endif // GOOGLE_CUDA +Status IrEmitterUnnested::EmitUnnestedTranspose( + mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { + auto* tiling_scheme = fusion_analysis.GetTransposeTilingScheme(); + // Set flag to false as Transpose has it's own custom logic of choosing a + // block size. + TF_ASSIGN_OR_RETURN(auto launch_dimensions, + fusion_analysis.GetLaunchDimensions( + /*use_experimental_block_size=*/false)); + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion(fusion, launch_dimensions)); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. + return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + TF_RETURN_IF_ERROR(EmitTransposeTile( + fusion, fusion_analysis.fused_computation(), + absl::MakeSpan(ir_arrays).subspan(0, fusion.getInputBuffers().size()), + absl::MakeSpan(ir_arrays).subspan(fusion.getInputBuffers().size()), + *tiling_scheme, launch_dimensions)); + return OkStatus(); +} + Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { auto fusion_op = mlir::cast(op); @@ -2016,13 +2047,14 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kTranspose: + return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: return EmitInputFusibleNonStridedSlices(op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: case HloFusionAnalysis::EmitterFusionKind::kReduction: - case HloFusionAnalysis::EmitterFusionKind::kTranspose: return FailedPrecondition( "Loop fusion should have been handled by GetFusionEmitter."); } @@ -3261,6 +3293,194 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } +llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, + llvm::Type* element_type, + llvm::Twine name) { + return b_.CreateAddrSpaceCast(input, + llvm::PointerType::get(element_type, + /*AddressSpace=*/0), + name); +} + +llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { + MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering::SequentiallyConsistent, + "workgroup"); + return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); +} + +static IrArray::Index PermuteIndex(const IrArray::Index& index, + absl::Span permutation) { + return IrArray::Index{Permute(index.multidim(), permutation), + Permute(index.dims(), permutation), index.GetType()}; +} + +Status IrEmitterUnnested::EmitTransposeTile( + mlir::lmhlo::FusionOp fusion, const HloComputation* fusion_hlo, + absl::Span operand_arrays, + absl::Span output_arrays, + const TilingScheme& tiling_scheme, + const LaunchDimensions& launch_dimensions) { + std::vector hlo_roots = + GetFusionRoots(const_cast(fusion_hlo)); + FusedIrEmitter 
fused_emitter(elemental_emitter_); + for (int i = 0; i < fusion_hlo->num_parameters(); i++) { + llvm_ir::IrArray ir_array = operand_arrays[i]; + HloInstruction* fused_operand = fusion_hlo->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, + [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement(index, &b_, + fused_operand->name()); + }); + } + + absl::flat_hash_map tiles; + Vector3 permutation; + for (const auto& [tile_idx, root] : llvm::enumerate(hlo_roots)) { + if (auto tr = FindAnyTiledTranspose(*root)) { + permutation = tr->permutation; + const HloInstruction& hero = FindNonTrivialHero(*root); + tiles[&hero] = + AllocateShared(tiling_scheme, + llvm_ir::PrimitiveTypeToIrType( + hero.operand(0)->shape().element_type(), module_), + {tiling_scheme.GetBlockTileSizeFor(permutation[kDimX]), + tiling_scheme.GetBlockTileSizeFor(kDimX) + 1}, + absl::StrCat("tr_tile_", tile_idx)); + } + } + + TileElementGenerator tile_generator = [&](const TilingThreadIdInfo& + thread_id_info, + const IrArray::Index& index, + ValueVector2 tile_dimensions) { + // Copy input parameter values to shared memory buffers: + // tile[thread_id_y, thread_id_x] = input[index] + // Note that tile_width and tile_height are flipped here because we + // are reading a transposed tile. + EmitTile( + &b_, tiling_scheme, index, thread_id_info, tile_dimensions, + [&](const TilingThreadIdInfo& thread_id_info, + const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + // Compute all extra output values before writing them. This avoids + // overwriting aliased input/output values before all reads occurred. + std::vector> + scheduled_writes; + + for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { + if (FindAnyTiledTranspose(*root)) { + const HloInstruction& hero = FindNonTrivialHero(*root); + llvm_ir::ElementGenerator input_gen = + *fused_emitter.GetGenerator(*hero.operand(0)); + IrArray::Index untiled_index = + GetUnnormalizedIndex(index, hero.operand(0)->shape(), &b_, + tiling_scheme.GetDimsInElems()); + llvm::Value* value = *input_gen(untiled_index); + llvm::Value* addr = thread_id_info.GEPIntoSharedMemory( + &b_, tiles[&hero], {y_loc, x_loc}); + + b_.CreateStore(value, addr); + } else { + IrArray::Index untiled_index = GetUnnormalizedIndex( + index, root->shape(), &b_, tiling_scheme.GetDimsInElems()); + llvm_ir::ElementGenerator output_gen = + *fused_emitter.GetGenerator(*root); + llvm::Value* output_value = *output_gen(untiled_index); + scheduled_writes.emplace_back(output_arrays[output_idx], + untiled_index, output_value); + } + } + + for (const auto& [output, idx, value] : scheduled_writes) { + output.EmitWriteArrayElement(idx, value, &b_); + } + }); + + EmitSyncThreads(); + + IrArray::Index output_tile_index = PermuteIndex(index, permutation); + ValueVector2 transposed_tile_dimensions = {tile_dimensions[1], + tile_dimensions[0]}; + + EmitTile( + &b_, tiling_scheme, output_tile_index, thread_id_info, + transposed_tile_dimensions, + /*emit_elem_function=*/ + [&](const TilingThreadIdInfo& thread_id_info, + const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { + if (FindAnyTiledTranspose(*root)) { + const HloInstruction& hero = FindNonTrivialHero(*root); + + std::vector idx = {x_loc, y_loc}; + llvm::Value* gep = + thread_id_info.GEPIntoSharedMemory(&b_, tiles[&hero], idx); + llvm::Type* type = + 
thread_id_info.GEPIntoSharedMemoryType(tiles[&hero], idx); + llvm::Value* loaded = b_.CreateLoad(type, gep, "tiled_buffer"); + + FusedIrEmitter fused_emitter(elemental_emitter_); + fused_emitter.BindGenerator( + hero, [&](const IrArray::Index& index) { return loaded; }); + for (int64_t i = 0; i < fusion_hlo->num_parameters(); ++i) { + llvm_ir::IrArray ir_array = operand_arrays[i]; + HloInstruction* fused_operand = + fusion_hlo->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, [this, ir_array, fused_operand]( + const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement( + index, &b_, fused_operand->name()); + }); + } + + // Apply codegeneration for the code after the real hero. + TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, + fused_emitter.GetGenerator(*root)); + + // Both for emission and writing it should be index-as-transformed + // by the computation. + IrArray::Index untiled_index = GetUnnormalizedIndex( + index, root->shape(), &b_, + Permute(tiling_scheme.GetDimsInElems(), permutation)); + TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index)); + output_arrays[output_idx].EmitWriteArrayElement(untiled_index, + generated, &b_); + } + } + return OkStatus(); + }); + }; + + llvm::Type* index_type = GetIndexTypeForKernel( + fusion.getOperation(), launch_dimensions.launch_bound(), &b_); + return EmitTilingKernel(&b_, tiling_scheme, index_type, tile_generator) + .status(); +} + +llvm::GlobalVariable* IrEmitterUnnested::AllocateShared( + const TilingScheme& tiling_scheme, llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name) { + CHECK(!dimensions_major_to_minor.empty()); + llvm::Type* array_type = nullptr; + for (int i = dimensions_major_to_minor.size() - 1; i >= 0; i--) { + // Iterate in minor-to-major order. + int64_t dim = dimensions_major_to_minor[i]; + if (!array_type) { + array_type = llvm::ArrayType::get(element_type, dim); + } else { + array_type = llvm::ArrayType::get(array_type, dim); + } + } + array_type = llvm::ArrayType::get(array_type, + tiling_scheme.GetThreadIdScalingFactor()); + return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(), + array_type, buffer_name); +} + // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 39f5b34e23238f..905ed163d64ead 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -324,6 +324,33 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenInfo& reduction_info, const ExtraOutputGensMap& extra_output_gens); + // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose + // algorithm to improve the memory access patterns for the input parameters + // with a shape that is a 0-2-1 transpose of the output tensor shape. The + // caller is responsible for making sure that it is safe to apply the shared + // memory transpose on the input parameters. + // + // + // For the purpose of tiling, the output tensors have a logical shape of three + // components 0-2-1 while the relevant input parameters have a logical shape + // of three components 0-1-2 in the order major to minor. The x- and y- + // dimensions of the tensors are tiled in square tiles with an edge length + // `kTileSize`. 
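As a rough illustration of the access pattern this comment describes (the sketch below is not part of the patch — the emitter generates the equivalent LLVM IR rather than CUDA source, and the kernel name, launch configuration, and `float` element type are assumptions made only for this example, mirroring the TensorFlow CUDA transpose referenced below):

```
// Illustrative sketch only -- not part of the patch. Shows the tiled 0-2-1
// transpose for a row-major [d0, d1, d2] input written to a [d0, d2, d1]
// output. Assumed launch: grid (ceil(d2 / kTileSize), ceil(d1 / kTileSize),
// d0), block (kTileSize, kNumRows, 1).
constexpr int kTileSize = 32;
constexpr int kNumRows = 4;

__global__ void Transpose021Sketch(const float* in, float* out, int d1,
                                   int d2) {
  // The "+ 1" column padding avoids shared-memory bank conflicts when the
  // tile is read back in transposed order.
  __shared__ float tile[kTileSize][kTileSize + 1];

  int b = blockIdx.z;               // position in dimension 0
  int x0 = blockIdx.x * kTileSize;  // tile origin in dimension 2
  int y0 = blockIdx.y * kTileSize;  // tile origin in dimension 1

  // Each thread copies kTileSize / kNumRows elements into the shared tile.
  for (int j = threadIdx.y; j < kTileSize; j += kNumRows) {
    int x = x0 + threadIdx.x;
    int y = y0 + j;
    if (x < d2 && y < d1) {
      tile[j][threadIdx.x] = in[(b * d1 + y) * d2 + x];
    }
  }
  __syncthreads();

  // Write the tile back with dimensions 1 and 2 swapped.
  for (int j = threadIdx.y; j < kTileSize; j += kNumRows) {
    int x = x0 + j;            // now iterates over dimension 2
    int y = y0 + threadIdx.x;  // now iterates over dimension 1
    if (x < d2 && y < d1) {
      out[(b * d2 + x) * d1 + y] = tile[threadIdx.x][j];
    }
  }
}
```

In the emitter itself the "regular HLO kernel" reads from the shared tile instead of writing a plain copy, but the tiling and bank-conflict padding follow the same shape.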
Each thread block of `kTileSize` x `kNumRows` threads + // transposes one tile: each thread copies kTileSize/kNumRows elements from + // the input to a shared memory tile, then the otherwise "regular HLO kernel" + // reads from the shared memory instead of the original input. + // + // This is similar to the following CUDA algorithm in TensorFlow: + // https://goo.gl/MStRV6. + // + // `kTileSize` should usually be same as warp size. We currently choose 32 for + // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. + // + // TODO(b/33320379): Here each block transposes 1 tile. It may be more + // efficient to launch fewer blocks so each transposes many tiles. + Status EmitUnnestedTranspose(mlir::lmhlo::FusionOp fusion, + HloFusionAnalysis& fusion_analysis); + // Generates code for input-fusible slices. // // Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes @@ -386,6 +413,17 @@ class IrEmitterUnnested : public IrEmitter { const HloComputation* fused_computation, HloFusionAnalysis& fusion_analysis); + // Allocates a shared tile of given dimensions, applying scaling specified in + // tilng_scheme as a major-most dimension to avoid collisions. + llvm::GlobalVariable* AllocateShared( + const TilingScheme& tiling_scheme, llvm::Type* element_type, + absl::Span dimensions_major_to_minor, + absl::string_view buffer_name = ""); + + // Removes some unneeded defining operations from the calculation of `value`, + // before passing it to a KernelThunk. + static StatusOr RemoveTransformingOperations(mlir::Value value); + // Builds a thunk that calls a new or reused kernel for a fusion operation. // // The caller must specify the same launch dimensions for fusions which have @@ -454,6 +492,9 @@ class IrEmitterUnnested : public IrEmitter { StatusOr> BuildConditionalThunk( const HloInstruction* conditional); + // Emit __syncthreads(), synchronization barrier for all threads in a block. + llvm::CallInst* EmitSyncThreads(); + StatusOr GetOrCreateSubComputationFromRegion( mlir::Region* region, bool is_fusion); @@ -472,6 +513,11 @@ class IrEmitterUnnested : public IrEmitter { scratch_nested_computations_; // End optional members for XLA HLO -> LMHLO. + // __shared__ memory uses a different address space, so we cast it to + // global address space before writing or reading. + llvm::Value* CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type, + llvm::Twine name = ""); + // Returns the ShapedSlices for the given operands. StatusOr> GetShapedSlices( mlir::Operation::operand_range operands); From 8df5908761cc67a25a1c0940e8ee1dd8bec8b489 Mon Sep 17 00:00:00 2001 From: Edward Schwartz Date: Tue, 25 Jul 2023 11:28:57 -0700 Subject: [PATCH 124/410] Fix docstring for Bincount MathOps by adding newlines after section headers PiperOrigin-RevId: 550946261 --- tensorflow/python/ops/bincount_ops.py | 1 + tensorflow/python/ops/sparse_ops.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/python/ops/bincount_ops.py b/tensorflow/python/ops/bincount_ops.py index 716cbd1a0d1ae2..d86507b6e40a09 100644 --- a/tensorflow/python/ops/bincount_ops.py +++ b/tensorflow/python/ops/bincount_ops.py @@ -88,6 +88,7 @@ def bincount(arr, [1, 1, 1, 0]], dtype=int32)> **Missing zeros in SparseTensor** + Note that missing zeros (implict zeros) in SparseTensor are **NOT** counted. This supports cases such as `0` in the values tensor indicates that index/id `0`is present and a missing zero indicates that no index/id is present. 
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index b4ae38ca6ac952..6688d9fb546866 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -3055,6 +3055,7 @@ def bincount(arr: sparse_tensor.SparseTensor, [1, 1, 1, 0]], dtype=int32)> **Missing zeros in SparseTensor** + Note that missing zeros (implict zeros) in SparseTensor are **NOT** counted. This supports cases such as `0` in the values tensor indicates that index/id `0`is present and a missing zero indicates that no index/id is present. From 1ae0951754040eed9549d921bd3462f0502912b8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 11:58:52 -0700 Subject: [PATCH 125/410] Extract reduction fusion code from ir_emitter_unnested. PiperOrigin-RevId: 550955045 --- tensorflow/compiler/xla/service/gpu/BUILD | 4 - .../compiler/xla/service/gpu/fusions/BUILD | 46 - .../xla/service/gpu/fusions/fusion_emitter.cc | 4 +- .../xla/service/gpu/fusions/fusion_emitter.h | 9 - .../xla/service/gpu/fusions/fusions.cc | 4 - .../xla/service/gpu/fusions/reduction.cc | 959 ------------------ .../xla/service/gpu/fusions/reduction.h | 120 --- .../xla/service/gpu/fusions/thunk_util.cc | 120 --- .../xla/service/gpu/fusions/thunk_util.h | 37 - .../xla/service/gpu/fusions/tiling_util.cc | 33 - .../xla/service/gpu/fusions/tiling_util.h | 5 - .../compiler/xla/service/gpu/ir_emitter.cc | 4 - .../xla/service/gpu/ir_emitter_unnested.cc | 871 +++++++++++++++- .../xla/service/gpu/ir_emitter_unnested.h | 158 +++ 14 files changed, 1025 insertions(+), 1349 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/reduction.h delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b164a34e6f8fb8..d2f7da5ffd450f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -309,7 +309,6 @@ cc_library( ":fft_thunk", ":gemm_thunk", ":gpu_asm_opts_util", - ":gpu_constants", ":gpu_conv_runner", ":gpu_device_info", ":gpu_executable", @@ -342,10 +341,8 @@ cc_library( "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/gpu/fusions", - "//tensorflow/compiler/xla/service/gpu/fusions:thunk_util", "//tensorflow/compiler/xla/service/gpu/fusions:tiling_util", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", - "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", @@ -409,7 +406,6 @@ cc_library( ":backend_configs_cc", ":hlo_fusion_analysis", ":hlo_to_ir_bindings", - ":ir_emission_utils", ":ir_emitter_context", ":kernel_reuse_cache", ":target_util", diff --git a/tensorflow/compiler/xla/service/gpu/fusions/BUILD b/tensorflow/compiler/xla/service/gpu/fusions/BUILD index 644544246200ce..0673e218c07396 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/BUILD +++ b/tensorflow/compiler/xla/service/gpu/fusions/BUILD @@ -61,7 +61,6 @@ cc_library( ":fusion_emitter", ":in_place_dynamic_update_slice", ":loop", - ":reduction", 
"//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", @@ -102,48 +101,3 @@ cc_library( "@llvm-project//llvm:ir_headers", ], ) - -cc_library( - name = "reduction", - srcs = ["reduction.cc"], - hdrs = ["reduction.h"], - deps = [ - ":fusion_emitter", - ":thunk_util", - ":tiling_util", - "//tensorflow/compiler/xla/hlo/ir:hlo", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/service/gpu:hlo_fusion_analysis", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:ir_emitter", - "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", - "//tensorflow/compiler/xla/service/gpu:kernel_reuse_cache", - "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", - "//tensorflow/compiler/xla/service/gpu:target_util", - "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ir_array", - "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", - "@llvm-project//llvm:ir_headers", - ], -) - -cc_library( - name = "thunk_util", - srcs = ["thunk_util.cc"], - hdrs = ["thunk_util.h"], - visibility = ["//tensorflow/compiler/xla/service/gpu:__subpackages__"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:ir_emitter_context", - "//tensorflow/compiler/xla/service/gpu:thunk", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", - "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - ], -) diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc index ab839a522c5df0..b4c60d1c2bc334 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.cc @@ -74,8 +74,6 @@ void AnnotateKernelLaunchDimensions(const LaunchDimensions& launch_dims, } } -} // namespace - std::tuple, std::vector> BuildKernelPrototype(IrEmitterContext& ir_emitter_context, @@ -173,6 +171,8 @@ BuildKernelPrototype(IrEmitterContext& ir_emitter_context, return {kernel, std::move(inputs), std::move(outputs)}; } +} // namespace + StatusOr KernelFusionEmitterBase::Emit( KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { std::string suggested_kernel_name = GetIrNameFromLoc(fusion_op_->getLoc()); diff --git a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h index b31247bd987da6..0fdc738a29eee5 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h @@ -80,15 +80,6 @@ class KernelFusionEmitterBase : public FusionInterface { const HloFusionInstruction& fusion_; }; -std::tuple, - std::vector> -BuildKernelPrototype(IrEmitterContext& ir_emitter_context, - const std::string& suggested_name, - absl::Span arguments, - size_t num_inputs, - const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* builder); - } // namespace gpu } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc index 34723c1ec64eef..a6c800556c5c7a 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/fusions.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusions/copy.h" #include "tensorflow/compiler/xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "tensorflow/compiler/xla/service/gpu/fusions/loop.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -50,9 +49,6 @@ std::optional> GetFusionEmitter( ElementalIrEmitter& elemental_emitter, mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion) { switch (analysis.GetEmitterFusionKind()) { - case HloFusionAnalysis::EmitterFusionKind::kReduction: - return std::make_unique( - ir_emitter_context, elemental_emitter, fusion_op, fusion, analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { bool is_single = IsSingleInstructionFusion(fusion_op); if (!is_single && CanEmitFusedDynamicUpdateSliceInPlaceForGpu( diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc b/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc deleted file mode 100644 index dfc1fbaa3f94d2..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/reduction.cc +++ /dev/null @@ -1,959 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/reduction.h" - -#include -#include - -#include "llvm/IR/IRBuilder.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_reuse_cache.h" -#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/target_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" - -namespace xla { -namespace gpu { -namespace { - -using TypedPointer = std::pair; - -// Fusion root -> array of indexes, one per reduction output. -using ReductionOutputMap = - ConstHloInstructionMap>; -using ExtraOutputGensMap = ConstHloInstructionMap; - -void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - auto* module = builder->GetInsertBlock()->getModule(); - if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().gcn_arch_name().substr( - 0, 6) == "gfx90a") { - builder->CreateFence( - llvm::AtomicOrdering::SequentiallyConsistent, - builder->getContext().getOrInsertSyncScopeID("workgroup")); - } -} - -void EmitSyncThreads(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); -} - -// For a row reduction, returns the number of rows we can process in parallel -// per warp. 
-int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { - return 1; - } - return WarpSize() / reduced_dimension_size; -} - -llvm::GlobalVariable* AllocateShared( - llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, - llvm::Type* element_type, - absl::Span dimensions_major_to_minor, - absl::string_view buffer_name) { - CHECK(!dimensions_major_to_minor.empty()); - llvm::Type* ty = element_type; - for (auto dim : llvm::reverse(dimensions_major_to_minor)) { - ty = llvm::ArrayType::get(ty, dim); - } - ty = llvm::ArrayType::get(ty, tiling_scheme.GetThreadIdScalingFactor()); - return llvm_ir::AllocateSharedMemoryTile( - builder->GetInsertBlock()->getModule(), ty, buffer_name); -} - -Status EmitExtraOutputsForReduce(llvm::IRBuilder<>* builder, - const Shape& reduction_operand_shape, - const ReductionOutputMap& result_ir_arrays, - const llvm_ir::IrArray::Index& index, - const ReductionCodegenInfo& reduction_info, - const ExtraOutputGensMap& extra_output_gens) { - if (extra_output_gens.empty()) { - return OkStatus(); - } - - // Compute all extra output values before writing them. This avoids - // overwriting aliased input/output buffers before all reads occurred. - std::vector> - extra_output_ir_values; - extra_output_ir_values.reserve(extra_output_gens.size()); - - auto get_index = [&](const HloInstruction* instr) { - const Shape& s = instr->shape(); - return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) - ? index - : index.SourceIndexOfBitcast(reduction_operand_shape, s, - builder); - }; - - for (const auto& [instr, generator] : extra_output_gens) { - TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, - generator(get_index(instr))); - extra_output_ir_values.emplace_back(instr, extra_output_ir_value); - } - - for (const auto& [instr, generator] : extra_output_ir_values) { - absl::Span result_ir = result_ir_arrays.at(instr); - CHECK_EQ(result_ir.size(), 1); - result_ir[0].EmitWriteArrayElement( - get_index(instr), generator, builder, /*use_linear_index=*/ - reduction_info.GetNumPartialResults() == 1); - } - return OkStatus(); -} - -ReductionCodegenState GenerateReductionCodegenState( - llvm::IRBuilder<>* builder, mlir::lmhlo::FusionOp fusion, - const ReductionCodegenInfo& reduction_info, - absl::Span reduce_instr_index_group, - FusedIrEmitter& fused_emitter) { - ReductionCodegenState reduction_codegen_state(reduction_info); - VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); - - for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - int num_partial_results = reduction_codegen_state.GetNumPartialResults(); - int num_outputs = reduce_hlo->shape().IsTuple() - ? reduce_hlo->shape().tuple_shapes_size() - : 1; - for (int op_result_idx = 0; op_result_idx < num_outputs; op_result_idx++) { - Shape result_shape = reduce_hlo->shape().IsTuple() - ? 
reduce_hlo->shape().tuple_shapes(op_result_idx) - : reduce_hlo->shape(); - - llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - result_shape.element_type(), builder->GetInsertBlock()->getModule()); - llvm::AllocaInst* reduction_input_address = - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "reduction_input_address", builder); - - llvm::AllocaInst* partial_result_address = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - element_type, - /*element_count=*/builder->getInt32(num_partial_results), - "partial_reduction_result", builder); - - const HloInstruction* init_value = - reduce_hlo->init_values()[op_result_idx]; - - // Initialize the partial result with the initial value of the reduction. - llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( - *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) - .value(); - - for (int i = 0; i < num_partial_results; ++i) { - builder->CreateStore( - init_ir_value, builder->CreateInBoundsGEP( - partial_result_address->getAllocatedType(), - partial_result_address, {builder->getInt32(i)})); - } - - const TilingScheme& tiling_scheme = - reduction_codegen_state.GetTilingScheme(); - int64_t num_threads_x = - tiling_scheme.GetNumThreadsFor(TilingScheme::DimX); - llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { - if (reduction_codegen_state.IsRowReduction()) { - // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > - 1) { - return nullptr; - } - // Allocate __shared__ - // cache[num_partial_results][num_warps][scaling_factor]. - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); - int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); - return AllocateShared(builder, tiling_scheme, element_type, - {num_partial_results, num_warps}, - "shared_cache"); - } else { - // Allocate __shared__ - // cache[num_threads][num_threads + 1], where - // num_threads == num_threads_x == num_threads_y. The "+1" is used to - // avoid bank conflicts. - // - // (Although each thread produces num_partial_results results, we - // don't need that much cache: Only one result is live at a time.) - CHECK_EQ(num_threads_x, - tiling_scheme.GetNumThreadsFor(TilingScheme::DimY)); - return AllocateShared(builder, tiling_scheme, element_type, - {num_threads_x, num_threads_x + 1}, - "shared_cache"); - } - }(); - - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - reduction_codegen_state.SetCalculationStateFor( - {shared_cache, init_ir_value, partial_result_address, - reduction_input_address, input_gen}, - reduce_hlo, op_result_idx); - } - } - - return reduction_codegen_state; -} - -// Generate a single element of the tile (update the accumulator state) for a -// given reducer of index `i`. 
-void GenerateElementForReducer( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const HloReduceInstruction* reduction, llvm::Value* partial_result_index, - const ReductionCodegenState& codegen_state, - const llvm_ir::IrArray::Index& index_without_linear, - const llvm_ir::IrArray::Index& input_index, int num_partial_results, - const ReductionOutputMap& result_ir_arrays) { - HloComputation* reducer = reduction->to_apply(); - CHECK_EQ(reducer->num_parameters() % 2, 0); - - absl::InlinedVector reduction_accumulators; - absl::InlinedVector reduction_input_value; - for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - codegen_state.GetCalculationStateFor(reduction, red_idx); - - llvm::AllocaInst* input_address = state.input_address; - llvm::AllocaInst* partial_reduction_result_address = - state.partial_result_address; - llvm::Value* const input_ir_value = *state.input_gen( - num_partial_results > 1 ? index_without_linear : input_index); - builder->CreateStore(input_ir_value, input_address); - llvm::Value* partial_result_address = builder->CreateInBoundsGEP( - partial_reduction_result_address->getAllocatedType(), - partial_reduction_result_address, {partial_result_index}); - reduction_accumulators.push_back(partial_result_address); - reduction_input_value.push_back(input_address); - } - - absl::InlinedVector reduction_params; - for (llvm::Value* acc : reduction_accumulators) { - reduction_params.push_back(acc); - } - for (llvm::Value* value : reduction_input_value) { - reduction_params.push_back(value); - } - - // Emit a call to the variadic reducer. Since it may be returning a - // tuple, we can't return it directly as a value. Instead, before - // the call, we create N (N = # arguments in the tuple) allocas, one - // for each returned argument, then when we make the call we pass N - // pointers as last parameters, the called computation writes into - // those pointers, and we have returned values on the stack (as well - // as pointers to them). - StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, - *reducer, reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); - } -} - -// Emits shuffle-down reduction for the `partial_result_address` using the -// reduction computation `reducer`, writes output into -// `partial_result_address`. -// -// Multiple partial_result_address inputs happen when doing variadic -// reduction: each one should get the output value. -void EmitFullWarpShuffleDownLoopForReduce( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp = 1) { - // This only works when the block size is a multiple of 32 threads. - - // We check this here as a mistake in the number of threads per - // block is very hard to detect. 
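For reference, the loop structure that EmitFullWarpShuffleDownLoopForReduce emits corresponds to the following CUDA sketch. It is illustrative only and not part of the patch: it assumes a single `float` accumulator, a `+` reducer, and `num_results_per_warp == 1`, whereas the emitter produces LLVM IR that calls the target shuffle intrinsic and the fused reducer computation (which may be variadic).

```
// Illustrative sketch only -- not part of the patch. Warp-level shuffle-down
// reduction: after log2(32) = 5 steps, lane 0 holds the sum of all 32
// lane-local partial results.
__device__ float WarpReduceSum(float partial) {
  for (int distance = 16; distance >= 1; distance /= 2) {
    // Each lane combines its value with the value held `distance` lanes away.
    partial += __shfl_down_sync(0xffffffff, partial, distance);
  }
  return partial;
}
```

When `num_results_per_warp > 1` (the multi-row reduction case), the emitted loop starts at `16 / num_results_per_warp` instead of 16, so each group of `WarpSize() / num_results_per_warp` lanes reduces its own row independently.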
- CHECK_EQ(threads_per_block % 32, 0); - CHECK_EQ(WarpSize() % num_results_per_warp, 0); - - for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { - absl::InlinedVector reduction_params; - - for (auto acc : partial_result_addresses) { - reduction_params.push_back(acc.first); - } - - for (auto [partial_result_address, element_type] : - partial_result_addresses) { - int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", builder); - - reduction_params.push_back(result_from_other_lane); - - // Bitcast cannot be applied to aggregate types (even packed ones), so - // we bitcast addresses of load/store to intN* of the same bit-width. - llvm::Type* shuffled_value_type = element_type->isStructTy() - ? builder->getIntNTy(bit_width) - : element_type; - auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { - return builder->CreatePointerBitCastOrAddrSpaceCast( - ptr, shuffled_value_type->getPointerTo()); - }; - - llvm::Value* partial_result = builder->CreateLoad( - shuffled_value_type, - convert_pointer_for_shuffle(partial_result_address), - "partial_reduction_result"); - builder->CreateStore( - EmitFullWarpShuffleDown(partial_result, builder->getInt32(distance), - builder), - convert_pointer_for_shuffle(result_from_other_lane)); - } - - StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs(builder, ir_emitter_context, - *reducer, reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); - } - } -} - -// Gets the output offset as calculated from thread_id.x (to be applied to the -// offset calculated from block_id and thread_id.y). -llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, - llvm::Value* thread_id_x, llvm::Type* index_ty, - llvm::IRBuilder<>* b) { - int64_t multiplier = - tiling_scheme.GetIndexingOrder() == TilingScheme::StridedIndexingX - ? 
tiling_scheme.GetVectorSize() - : tiling_scheme.GetTileSizeFor(TilingScheme::DimX); - return b->CreateMul(thread_id_x, - llvm::ConstantInt::get(index_ty, multiplier)); -} - -llvm::Value* GetOutputAddressForReduction( - llvm::IRBuilder<>* builder, int partial_result_idx, llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int output_idx) { - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; - - llvm_ir::IrArray::Index start_offset = [&] { - llvm::Value* x_loc = thread_id_info.thread_id_x; - llvm::Value* y_loc = thread_id_info.thread_id_y; - if (!reduction_codegen_state.IsRowReduction()) { - std::swap(x_loc, y_loc); - } - llvm::Value* start_offset_x = - GetStartOffsetX(tiling_scheme, x_loc, index_ty, builder); - return tiling_kernel_info.tile_origin - .AddOffsetToDim(y_loc, TilingScheme::DimY, builder) - .AddOffsetToDim(start_offset_x, TilingScheme::DimX, builder); - }(); - - const llvm_ir::IrArray& output_array = - output_arrays.at(reduction)[output_idx]; - const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); - Shape reduction_kept_element_shape = - ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); - - // Given the IrArray index of a reduction input, returns the linear address of - // the reduction output as if the reduction were going to keep the input shape - // with the dimensions being reduced moved. - llvm::Value* untransposed_output_linear_address = [&] { - const llvm_ir::IrArray::Index index = start_offset.AddOffsetToDim( - constant(partial_result_idx), TilingScheme::DimX, builder); - if (reduction_codegen_state.IsRowReduction()) { - // For row-reduction, y-coordinate determines which row we write into. - return index[TilingScheme::DimY]; - } - // For column reduction, we get the transposed address. - absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); - llvm::Value* x_dim_size = - index.GetConstantWithIndexType(dims_in_elem[TilingScheme::DimX]); - llvm::Value* x_block_offset = - builder->CreateMul(index[TilingScheme::DimZ], x_dim_size); - return builder->CreateAdd(x_block_offset, index[TilingScheme::DimX]); - }(); - - // A reduction is allowed to transpose its output. For example, suppose - // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are - // allowed to produce as output either f32[10,30]{1,0} (no transpose) or - // f32[10,30]{0,1} (transposing the two output dims). - // - // At this point in the function we have a "partial sum" of input elements - // (stored in partial_result_addresses), and we need to accumulate it into - // the correct output element. - llvm_ir::IrArray::Index element_index( - /*linear=*/untransposed_output_linear_address, - reduction_kept_element_shape, builder); - llvm_ir::IrArray::Index output_index(element_index.multidim(), - output_array.GetShape(), - element_index.GetType()); - - return output_array.EmitArrayElementAddress(output_index, builder, - "output_element_address"); -} - -// Wraps up the code generation for a tile block of a reduction kernel: -// write the calculated output into the output tensor. 
-void WriteReductionOutput(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - llvm::Type* index_ty, - const ReductionCodegenState& reduction_codegen_state, - const TilingKernelInfo& tiling_kernel_info, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, - int partial_result_idx, - const absl::Span values) { - const HloComputation* reducer = reduction->to_apply(); - for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { - auto [output_ptr, type] = typed_ptr; - llvm::Value* output_address = GetOutputAddressForReduction( - builder, partial_result_idx, index_ty, reduction_codegen_state, - tiling_kernel_info, output_arrays, reduction, oidx); - if (reduction_codegen_state.IsRaceFree()) { - builder->CreateStore(builder->CreateLoad(type, output_ptr, "output"), - output_address); - } else { - CHECK_EQ(values.size(), 1); - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder, ir_emitter_context, *reducer, output_address, output_ptr, - type)); - } - } -} - -// `current_output`: the value the tile has calculated. -// `output_address`: address where the output value has to be written. -void EmitReductionOutputForRowReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - - int num_outputs = reducer->num_parameters() / 2; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - absl::InlinedVector current_outputs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - current_outputs.push_back( - {builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"), - state.partial_result_address->getAllocatedType()}); - } - - int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; - int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); - EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(current_outputs), - tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); - - KernelSupportLibrary ksl(builder); - llvm::Value* warp_id = - builder->CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); - - auto emit_write_output = [&](llvm::Value* write_condition, - const absl::Span values) { - ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(builder, ir_emitter_context, index_ty, - reduction_codegen_state, tiling_kernel_info, - output_arrays, reduction, partial_result_idx, - values); - }); - }; - - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } - - ksl.If("intra_warp_reduce_write", 
is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, {constant(partial_result_idx), warp_id}); - builder->CreateStore(builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - shmem_output_addr); - } - }); - - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? - EmitSyncThreads(builder, ir_emitter_context); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {constant(partial_result_idx), thread_id_info.lane_id}); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - // Ensure initial value address is in generic, not scratch. - llvm::Value* initial_value_addr = builder->CreateAddrSpaceCast( - llvm_ir::EmitAllocaAtFunctionEntry(element_type, "initial_value_addr", - builder), - llvm::PointerType::get(element_type, - /*AddressSpace=*/0)); - builder->CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = builder->CreateICmpULT( - thread_id_info.thread_id_x, - constant(tiling_scheme.GetNumThreadsFor(TilingScheme::DimX) / - WarpSize())); - - llvm::Value* selected_value = builder->CreateSelect( - warp_exists, block_accum_addr, initial_value_addr); - - selected_values.push_back({selected_value, element_type}); - } - - // If only one warp is present in the block, then we don't need inter-warp - // reduction. - // TODO(b/241414088) If only warp is present, then inter-warp communication - // using shared memory and synchronization using barrier is also unnecessary - // and should be removed. - if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - builder, ir_emitter_context, reducer, absl::MakeSpan(selected_values), - tiling_scheme.GetNumThreadsPerBlock()); - } - - emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); - }); -} - -// Same arguments as EmitReductionOutputForRowReduction. -void EmitReductionOutputForColumnReduction( - llvm::IRBuilder<>* builder, IrEmitterContext& ir_emitter_context, - const TilingKernelInfo& tiling_kernel_info, - const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, - const ReductionOutputMap& output_arrays, - const HloReduceInstruction* reduction, int partial_result_idx) { - KernelSupportLibrary ksl(builder); - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); - int num_outputs = reducer->num_parameters() / 2; - - // Wait for reads from shmem in the last iteration to complete. (If this is - // slow, we could "double-buffer" by having two shmem buffers and switching - // between them.) 
- if (partial_result_idx > 0) { - EmitSyncThreads(builder, ir_emitter_context); - } - - // Store the transpose in shared memory. - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::GlobalVariable* shared_cache = state.shared_cache; - llvm::AddrSpaceCastInst* shmem_output_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, shared_cache, - {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, - "shmem_output_address")); - llvm::Value* current_output = builder->CreateInBoundsGEP( - state.partial_result_address->getAllocatedType(), - state.partial_result_address, {constant(partial_result_idx)}, - "current_output"); - - llvm::Value* current_output_value = builder->CreateLoad( - state.partial_result_address->getAllocatedType(), current_output); - builder->CreateStore(current_output_value, shmem_output_addr); - } - - EmitSyncThreads(builder, ir_emitter_context); - - // Get transposed element from shared memory. - absl::InlinedVector shmem_transposed_addrs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const ReductionCodegenState::ReductionCalculationState& state = - reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); - llvm::AddrSpaceCastInst* shmem_transposed_addr = - llvm::cast(thread_id_info.GEPIntoSharedMemory( - builder, state.shared_cache, - {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, - "shmem_transposed_addr")); - shmem_transposed_addrs.push_back( - {shmem_transposed_addr, llvm::cast( - shmem_transposed_addr->getPointerOperand()) - ->getResultElementType()}); - } - - EmitFullWarpShuffleDownLoopForReduce(builder, ir_emitter_context, reducer, - absl::MakeSpan(shmem_transposed_addrs), - tiling_scheme.GetNumThreadsPerBlock()); - - // Some warps in the block are completely outside of the bound of the - // tensor, so they should not write any output at all. - llvm::Value* has_output = builder->CreateAnd( - builder->CreateICmpULT( - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_y, index_ty, - builder), - tiling_kernel_info.output_tile_bounds[1]), - builder->CreateICmpULT(thread_id_info.thread_id_x, - tiling_kernel_info.output_tile_bounds[0])); - - ksl.If("reduction_write_output", - builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(builder, ir_emitter_context, index_ty, - reduction_codegen_state, tiling_kernel_info, - output_arrays, reduction, partial_result_idx, - shmem_transposed_addrs); - }); -} - -// Emits code for reductions in the output_instructions. 
-Status EmitIRForReduction(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context, - mlir::lmhlo::FusionOp fusion, - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, - const ReductionOutputMap& result_ir_arrays, - const ReductionCodegenInfo& reduction_info, - const Shape& input_shape) { - std::vector reductions; - ExtraOutputGensMap extra_output_gens; - - for (const HloInstruction* hlo : instr_index_group) { - if (IsReductionFromOrToContiguousDimensions(*hlo)) { - reductions.push_back(Cast(hlo)); - } else { - extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); - } - } - - CHECK(!reductions.empty()) << " expect at least one reduce instructions."; - const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); - CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); - llvm::Type* index_ty = - GetIndexTypeForKernel(fusion, - tiling_scheme.GetNumThreadsPerBlockPhysical() * - tiling_scheme.GetNumberOfBlocksPhysical(), - builder); - ReductionCodegenState codegen_state = GenerateReductionCodegenState( - builder, fusion, reduction_info, reductions, fused_emitter); - - EmitTileElementFunction emit_reduction_element = - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc) { - llvm_ir::IrArray::Index input_index = GetUnnormalizedIndex( - index, input_shape, builder, - codegen_state.GetTilingScheme().GetDimsInElems()); - llvm::Value* partial_result_index = - codegen_state.IsRowReduction() - ? builder->getInt32(0) - : builder->CreateSub( - x_loc, - GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, - index_ty, builder)); - - // Clear the linear index field of the IrArray::Index to enable the use - // of GetElementPointer with array types. This enables the vectorization - // of the computation for different partial results. Use this index if - // 'num_partial_results > 1'. - int num_partial_results = codegen_state.GetNumPartialResults(); - llvm_ir::IrArray::Index index_without_linear{ - input_index.multidim(), input_shape, input_index.GetType()}; - - // Emit code to generate the input and perform the reduction computation - // for each reduction instruction. - for (const HloReduceInstruction* reduce : reductions) { - GenerateElementForReducer(builder, ir_emitter_context, reduce, - partial_result_index, codegen_state, - index_without_linear, input_index, - num_partial_results, result_ir_arrays); - } - - // Emit code to generate the output for the non-reduction instructions - // in the fusion, if any. 
- TF_CHECK_OK(EmitExtraOutputsForReduce( - builder, input_shape, result_ir_arrays, input_index, reduction_info, - extra_output_gens)); - }; - - TF_ASSIGN_OR_RETURN( - TilingKernelInfo tiling_kernel_info, - EmitTilingKernel(builder, tiling_scheme, index_ty, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& index, - std::array tile_dimensions) { - EmitTile(builder, codegen_state.GetTilingScheme(), - index, thread_id_info, tile_dimensions, - emit_reduction_element); - })); - - KernelSupportLibrary ksl(builder); - for (const HloReduceInstruction* reduce : reductions) { - for (int partial_result_idx = 0; - partial_result_idx < reduction_info.GetNumPartialResults(); - ++partial_result_idx) { - if (codegen_state.IsRowReduction()) { - EmitReductionOutputForRowReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, partial_result_idx); - } else { - EmitReductionOutputForColumnReduction( - builder, ir_emitter_context, tiling_kernel_info, codegen_state, - index_ty, result_ir_arrays, reduce, partial_result_idx); - } - } - } - - return OkStatus(); -} - -StatusOr> BuildKernelThunkForFusion( - IrEmitterContext& ir_emitter_context, KernelReuseCache& kernel_cache, - mlir::lmhlo::FusionOp fusion_op, const HloInstruction& fusion, - const LaunchDimensions& launch_dimensions, absl::string_view discriminator, - std::function, - std::vector)> - kernel_builder_fn, - llvm::IRBuilder<>* builder) { - TF_ASSIGN_OR_RETURN( - auto kernel_arguments, - KernelArguments::Create(ir_emitter_context.allocations(), fusion_op)); - - Status kernel_builder_status = OkStatus(); - auto [entry, cached] = kernel_cache.Get( - fusion.fused_instructions_computation(), kernel_arguments.args(), - discriminator, [&]() -> KernelReuseCache::Entry { - std::vector inputs, outputs; - llvm::Function* kernel; - std::tie(kernel, inputs, outputs) = BuildKernelPrototype( - ir_emitter_context, GetIrNameFromLoc(fusion_op->getLoc()), - kernel_arguments.args(), fusion.fused_parameters().size(), - launch_dimensions, builder); - kernel_builder_status = - kernel_builder_fn(std::move(inputs), std::move(outputs)); - return {kernel->getName().str(), launch_dimensions}; - }); - TF_RETURN_IF_ERROR(kernel_builder_status); - - return std::make_unique( - fusion_op, entry.kernel_name, kernel_arguments.args(), launch_dimensions); -} - -StatusOr> BuildFusedInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::lmhlo::FusionOp fusion_op, - const HloInstruction& fusion, ElementalIrEmitter& elemental_emitter, - KernelReuseCache& kernel_cache, int output_index, - llvm::IRBuilder<>* builder) { - auto reduce = mlir::dyn_cast_or_null( - fusion_op.getFusionRoots()[output_index]); - - TF_RET_CHECK(reduce); - TF_RET_CHECK(reduce.getNumResults() == 1); - - mlir::Value init_value = reduce.getInitValues()[0]; - mlir::Value dest = fusion_op.getOutputBuffers()[output_index]; - TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, - BuildConstantInitializerThunk( - ir_emitter_context, fusion_op, init_value, dest)); - if (constant_init_thunk) { - return *std::move(constant_init_thunk); - } - - auto input_buffers = fusion_op.getInputBuffers(); - - const Shape dest_shape = GetShape(dest); - bool use_experimental_block_size = - ir_emitter_context.debug_options() - .xla_gpu_enable_experimental_block_size(); - - TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, - CalculateLaunchDimensions( - dest_shape, ir_emitter_context.gpu_device_info(), - use_experimental_block_size)); - - const 
HloComputation* fused_computation = - fusion.fused_instructions_computation(); - HloInstruction* instr = fused_computation->root_instruction(); - if (instr->opcode() != HloOpcode::kTuple) { - CHECK_EQ(0, output_index); - } else { - instr = instr->mutable_operand(output_index); - } - TF_RET_CHECK(instr->shape().IsArray()); - - auto kernel_builder = [&](std::vector inputs, - std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [builder, &inputs, - i](llvm_ir::IrArray::Index index) -> StatusOr { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - TF_ASSIGN_OR_RETURN(auto generator, - fused_emitter.GetGenerator(*instr->operand(1))); - return ParallelLoopEmitter(generator, {outputs[0]}, launch_dimensions, - builder) - .EmitLoop(GetIrNameFromLoc(fusion_op.getLoc())); - }; - return BuildKernelThunkForFusion( - ir_emitter_context, kernel_cache, fusion_op, fusion, launch_dimensions, - /*discriminator=*/ - absl::StrCat("init_", output_index), kernel_builder, builder); -} - -} // namespace - -StatusOr ReductionFusion::Emit( - KernelReuseCache& kernel_cache, llvm::IRBuilder<>* builder) const { - auto* reduction_codegen_info = analysis_.GetReductionCodegenInfo(); - // Set `use_experimental_block_size` flag to false as the reduction code - // has its own custom logic of choosing a block size. - TF_ASSIGN_OR_RETURN(auto launch_dimensions, - analysis_.GetLaunchDimensions( - /*use_experimental_block_size=*/false)); - - FusionEmissionResult result; - VLOG(3) << "Launch dimensions of " - << mlir::mhlo::GetDebugNameFromLocation(fusion_op().getLoc()) << ": " - << launch_dimensions.ToString(); - if (!reduction_codegen_info->IsRaceFree()) { - absl::Span fusion_roots = analysis_.fusion_roots(); - for (int i = 0; i < fusion_roots.size(); ++i) { - if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { - TF_ASSIGN_OR_RETURN(result.thunks.emplace_back(), - BuildFusedInitializerThunk( - ir_emitter_context_, fusion_op(), fusion_, - elemental_emitter_, kernel_cache, i, builder)); - } - } - } - - auto kernel_builder = [&](std::vector inputs, - std::vector outputs) -> Status { - FusedIrEmitter fused_emitter(elemental_emitter_); - const HloComputation* fused_computation = analysis_.fused_computation(); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - llvm_ir::IrArray ir_array = inputs[i]; - HloInstruction* fused_operand = - fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, - [builder, ir_array, fused_operand]( - const llvm_ir::IrArray::Index& index) -> StatusOr { - return ir_array.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - // Get outputs. - ReductionOutputMap result_ir_arrays; - - // Skip all parameter buffers first. - int ir_arrays_idx = 0; - auto outputs_span = absl::MakeSpan(outputs); - for (HloInstruction* root : analysis_.fusion_roots()) { - int num_results = - root->shape().IsTuple() ? root->shape().tuple_shapes_size() : 1; - result_ir_arrays[root] = outputs_span.subspan(ir_arrays_idx, num_results); - ir_arrays_idx += num_results; - } - - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - - // Use raw block_id_y to select the i-th parallel reduction to run. 
Using - // block_id_y instead of block_id_x simplifies the index calculation - // for reduction code generation as the block_id_y is orthogonal to - // the indices used within the reductions. - const std::vector>& instr_index_groups = - reduction_codegen_info->GetIndexGroups(); - Shape reduce_operand_shape = - reduction_codegen_info->GetReduceOperandShape(); - - llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(raw_block_id_y)); - for (int i = 0; i < instr_index_groups.size(); ++i) { - TF_RETURN_IF_ERROR(ksl.IfWithStatus( - absl::StrCat("reduce-group-", i), - builder->CreateICmpEQ(raw_block_id_y, builder->getInt32(i)), [&] { - return EmitIRForReduction(builder, ir_emitter_context_, fusion_op(), - instr_index_groups[i], fused_emitter, - result_ir_arrays, *reduction_codegen_info, - reduce_operand_shape); - })); - } - - return OkStatus(); - }; - - TF_ASSIGN_OR_RETURN( - result.thunks.emplace_back(), - BuildKernelThunkForFusion(ir_emitter_context_, kernel_cache, fusion_op(), - fusion_, launch_dimensions, "", kernel_builder, - builder)); - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/reduction.h b/tensorflow/compiler/xla/service/gpu/fusions/reduction.h deleted file mode 100644 index 5bb45d4e3c01ea..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/reduction.h +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ - -#include "tensorflow/compiler/xla/service/gpu/fusions/fusion_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" - -namespace xla { -namespace gpu { - -// Generates code for reduction to contiguous dimensions. 
-// -// Row reduction uses the following algorithm described in CUDA-like -// pseudocode: -// -// ``` -// __global__ void reduce(int num_rows, float *in, float out) { -// __shared__ float[32] cache; -// int offset = blockDim.x * blockIdx.x + threadIdx.x; -// if (offset >= num_rows) return; -// int tile_bound = std::min(offset + kTileSizeX, num_rows); -// float accum = 0; -// for (int i=offset; i Emit( - KernelReuseCache& kernel_cache, - llvm::IRBuilder<>* builder) const override; - - private: - mlir::lmhlo::FusionOp fusion_op() const { return fusion_op_; } - - IrEmitterContext& ir_emitter_context_; - ElementalIrEmitter& elemental_emitter_; - mlir::lmhlo::FusionOp fusion_op_; - const HloFusionInstruction& fusion_; - HloFusionAnalysis& analysis_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc deleted file mode 100644 index e38504096376a8..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" - -#include -#include - -#include "absl/types/span.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/shape.h" -#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_utils.h" - -namespace xla { -namespace gpu { -namespace { - -// TODO(b/291536641): Clean this up. What's the difference between this and the -// caller? -std::optional> BuildConstantInitializerThunk( - mlir::Operation* op, absl::Span init_value, mlir::Value dest, - const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { - int64_t num_bytes = init_value.size(); - if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { - return {{std::make_unique(Thunk::ThunkInfo(op), dest_slice, - dest)}}; - } - - // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by - // repeating the literal 4 or 2 times, so long as the destination buffer is - // an even multiple of 32 bits long. 
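As a standalone illustration of the pattern-widening trick described in the comment above, the sketch below shows how a 1- or 2-byte initializer is repeated into the 32-bit word that a single memset can write, assuming a little-endian host. The helper name PackTo32BitPattern and the sample values are illustrative only and are not part of this change.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <optional>
#include <vector>

// Widens a 1- or 2-byte initializer into the 32-bit pattern a single memset
// can write, provided the destination buffer is a whole number of 32-bit
// words. Returns nullopt when the trick does not apply.
std::optional<uint32_t> PackTo32BitPattern(const std::vector<uint8_t>& init,
                                           size_t dest_size_bytes) {
  const size_t num_bytes = init.size();
  if ((num_bytes != 1 && num_bytes != 2) || dest_size_bytes % 4 != 0) {
    return std::nullopt;
  }
  uint16_t pattern16;
  if (num_bytes == 1) {
    const uint8_t b = init.front();
    pattern16 = uint16_t{b} | (uint16_t{b} << 8);  // repeat the byte twice
  } else {
    std::memcpy(&pattern16, init.data(), sizeof(pattern16));
  }
  return uint32_t{pattern16} | (uint32_t{pattern16} << 16);  // repeat again
}

int main() {
  if (auto p = PackTo32BitPattern({0xAB}, /*dest_size_bytes=*/64)) {
    std::printf("0x%08X\n", static_cast<unsigned>(*p));  // 0xABABABAB
  }
  if (auto p = PackTo32BitPattern({0xEF, 0xBE}, /*dest_size_bytes=*/64)) {
    std::printf("0x%08X\n", static_cast<unsigned>(*p));  // 0xBEEFBEEF
  }
  return 0;
}

All-zero initializers never reach this path in the surrounding code, since they are handled first by a plain memzero.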
- if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { - uint16_t pattern16; - if (num_bytes == 1) { - uint8_t b = init_value.front(); - pattern16 = uint16_t{b} | (uint16_t{b} << 8); - } else { - memcpy(&pattern16, init_value.data(), sizeof(pattern16)); - } - uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); - return {{std::make_unique( - Thunk::ThunkInfo(op), pattern32, dest_slice, dest)}}; - } - - // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit - // memset so long as all 32-bit words of the scalar are equal to each other. - if (num_bytes >= 4 && num_bytes % 4 == 0 && - memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == - 0) { - uint32_t word; - memcpy(&word, init_value.data(), sizeof(word)); - return {{std::make_unique(Thunk::ThunkInfo(op), word, - dest_slice, dest)}}; - } - - return std::nullopt; -} - -} // namespace - -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - mlir::Value init_value, mlir::Value dest) { - mlir::DenseElementsAttr const_init; - if (auto get_global_memref = - mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - auto global_memref = - mlir::SymbolTable::lookupNearestSymbolFrom( - get_global_memref, get_global_memref.getNameAttr()); - if (global_memref.getConstant() && global_memref.getInitialValue()) { - // If the initial value happens to be a constant, generate a specialized - // thunk. - const_init = global_memref.getInitialValue() - .value() - .cast(); - } - } else if (auto constant = mlir::dyn_cast_or_null( - init_value.getDefiningOp())) { - const_init = constant.getValue().dyn_cast(); - } - - if (const_init) { - std::vector literal_bytes; - TF_RETURN_IF_ERROR( - CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); - - TF_ASSIGN_OR_RETURN( - auto dest_slice, - GetAllocationSlice(dest, ir_emitter_context.allocations())); - - const Shape dest_shape = GetShape(dest); - return BuildConstantInitializerThunk(op, literal_bytes, dest, dest_slice, - dest_shape); - } - return std::nullopt; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h b/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h deleted file mode 100644 index b52c535f034654..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Attempts to build an initializer constant for the given value. Returns an -// empty optional if the value is not a constant. -StatusOr>> BuildConstantInitializerThunk( - IrEmitterContext& ir_emitter_context, mlir::Operation* op, - mlir::Value init_value, mlir::Value dest); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSIONS_THUNK_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc index 8d0ca80b2d9182..c0c2640a895937 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.cc @@ -360,38 +360,5 @@ llvm::Value* TilingThreadIdInfo::GEPIntoSharedMemory( return b->CreateAddrSpaceCast(gep, pointer_in_addressspace); } -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems) { - CHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. 
- if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[1], multidim[2]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && - unnormalized_shape.layout().minor_to_major(1) == 1) { - CHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return llvm_ir::IrArray::Index({multidim[2], multidim[1]}, - unnormalized_shape, - normalized_shape_index.GetType()); - } - return normalized_shape_index.SourceIndexOfBitcast( - ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, builder); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h index be697addeb78e0..1d8dd40e916300 100644 --- a/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h +++ b/tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h @@ -138,11 +138,6 @@ StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const TilingScheme& tiling_scheme, llvm::Type* index_ty, const TileElementGenerator& tile_element_generator); -llvm_ir::IrArray::Index GetUnnormalizedIndex( - const llvm_ir::IrArray::Index& normalized_shape_index, - const Shape& unnormalized_shape, llvm::IRBuilder<>* builder, - absl::Span dims_in_elems); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 95d0a67c6c4f27..3070cbf9b1dfa7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,17 +21,13 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 3e23e291c96330..d51ef9046c6e5b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -92,7 +92,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fused_mha_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fusions/fusions.h" -#include "tensorflow/compiler/xla/service/gpu/fusions/thunk_util.h" #include "tensorflow/compiler/xla/service/gpu/fusions/tiling_util.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" @@ -2047,6 +2046,8 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { #endif LOG(FATAL) << "Unsupported fusion kind: " << backend_config.kind(); } + case HloFusionAnalysis::EmitterFusionKind::kReduction: + return EmitUnnestedReduction(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kTranspose: return EmitUnnestedTranspose(fusion_op, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kInputSlices: @@ -2054,7 +2055,6 @@ Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) { case HloFusionAnalysis::EmitterFusionKind::kScatter: return EmitScatter(fusion_op, fused_computation, fusion_analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: - case HloFusionAnalysis::EmitterFusionKind::kReduction: return FailedPrecondition( "Loop fusion should have been handled by GetFusionEmitter."); } @@ -3207,6 +3207,86 @@ IrEmitterUnnested::BuildKernelThunkForNonFusionOp( launch_dimensions); } +std::unique_ptr IrEmitterUnnested::BuildConstantInitializerThunk( + mlir::Operation* op, absl::Span init_value, mlir::Value dest, + const BufferAllocation::Slice& dest_slice, const Shape& output_shape) { + int64_t num_bytes = init_value.size(); + if (absl::c_all_of(init_value, [](uint8_t byte) { return byte == 0; })) { + return std::make_unique(Thunk::ThunkInfo(op), dest_slice, + dest); + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { + uint16_t pattern16; + if (num_bytes == 1) { + uint8_t b = init_value.front(); + pattern16 = uint16_t{b} | (uint16_t{b} << 8); + } else { + memcpy(&pattern16, init_value.data(), sizeof(pattern16)); + } + uint32_t pattern32 = uint32_t{pattern16} | (uint32_t{pattern16} << 16); + return std::make_unique(Thunk::ThunkInfo(op), + pattern32, dest_slice, dest); + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) == + 0) { + uint32_t word; + memcpy(&word, init_value.data(), sizeof(word)); + return std::make_unique(Thunk::ThunkInfo(op), word, + dest_slice, dest); + } + + return nullptr; +} + +StatusOr> +IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Operation* op, + mlir::Value init_value, + mlir::Value dest) { + mlir::DenseElementsAttr const_init; + if (auto get_global_memref = + mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + auto global_memref = + mlir::SymbolTable::lookupNearestSymbolFrom( + get_global_memref, get_global_memref.getNameAttr()); + if (global_memref.getConstant() && global_memref.getInitialValue()) { + // If the initial value happens to be a constant, generate a specialized + // thunk. 
+ const_init = global_memref.getInitialValue() + .value() + .cast(); + } + } else if (auto constant = mlir::dyn_cast_or_null( + init_value.getDefiningOp())) { + const_init = constant.getValue().dyn_cast(); + } + + if (const_init) { + std::vector literal_bytes; + TF_RETURN_IF_ERROR( + CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes)); + + TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest)); + + const Shape dest_shape = GetShape(dest); + auto thunk = BuildConstantInitializerThunk(op, literal_bytes, dest, + dest_slice, dest_shape); + if (thunk) { + return {std::move(thunk)}; + } + } + return std::unique_ptr(); +} + Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest) { @@ -3214,11 +3294,10 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, auto init_type = init_value.getType().dyn_cast(); TF_RET_CHECK(init_type.getRank() == 0); - TF_ASSIGN_OR_RETURN(std::optional> constant_init_thunk, - BuildConstantInitializerThunk(*ir_emitter_context_, op, - init_value, dest)); + TF_ASSIGN_OR_RETURN(std::unique_ptr constant_init_thunk, + TryBuildConstantInitializerThunk(op, init_value, dest)); if (constant_init_thunk) { - AddThunkToThunkSequence(*std::move(constant_init_thunk)); + AddThunkToThunkSequence(std::move(constant_init_thunk)); return OkStatus(); } @@ -3250,6 +3329,77 @@ Status IrEmitterUnnested::BuildInitializerThunk(mlir::Operation* op, return OkStatus(); } +Status IrEmitterUnnested::BuildFusedInitializerThunk( + mlir::lmhlo::FusionOp fusion, int output_index) { + auto reduce = mlir::dyn_cast_or_null( + fusion.getFusionRoots()[output_index]); + + TF_RET_CHECK(reduce); + TF_RET_CHECK(reduce.getNumResults() == 1); + + mlir::Value init_value = reduce.getInitValues()[0]; + mlir::Value dest = fusion.getOutputBuffers()[output_index]; + TF_ASSIGN_OR_RETURN( + std::unique_ptr constant_init_thunk, + TryBuildConstantInitializerThunk(fusion, init_value, dest)); + if (constant_init_thunk) { + AddThunkToThunkSequence(std::move(constant_init_thunk)); + return OkStatus(); + } + + auto input_buffers = fusion.getInputBuffers(); + + const Shape dest_shape = GetShape(dest); + bool use_experimental_block_size = + ir_emitter_context_->debug_options() + .xla_gpu_enable_experimental_block_size(); + + TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions, + CalculateLaunchDimensions( + dest_shape, ir_emitter_context_->gpu_device_info(), + use_experimental_block_size)); + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion( + fusion, launch_dimensions, + /*discriminator=*/absl::StrCat("init_", output_index))); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. 
+ return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + const llvm_ir::IrArray dest_array = + ir_arrays[input_buffers.size() + output_index]; + + const HloComputation* fused_computation = + *GetOrCreateSubComputationFromRegion(&fusion.getRegion(), + /*is_fusion=*/true); + + FusedIrEmitter fused_emitter(elemental_emitter_); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + fused_emitter.BindGenerator( + *fused_computation->parameter_instruction(i), + [this, &ir_arrays, i](llvm_ir::IrArray::Index index) { + return ir_arrays[i].EmitReadArrayElement(index, &b_); + }); + } + HloInstruction* instr = fused_computation->root_instruction(); + if (instr->opcode() != HloOpcode::kTuple) { + CHECK_EQ(0, output_index); + } else { + instr = instr->mutable_operand(output_index); + } + TF_RET_CHECK(instr->shape().IsArray()); + TF_ASSIGN_OR_RETURN(auto generator, + fused_emitter.GetGenerator(*instr->operand(1))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(generator, {dest_array}, launch_dimensions, &b_) + .EmitLoop(GetIrNameFromLoc(fusion.getLoc()))); + return OkStatus(); +} + StatusOr> IrEmitterUnnested::BuildWhileThunk( mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) { // Generate thunk sequence for while 'condition'. @@ -3293,6 +3443,263 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return InternalError("This should be unreachable"); } +// Gets the output offset as calculated from thread_id.x (to be applied to the +// offset calculated from block_id and thread_id.y). +static llvm::Value* GetStartOffsetX(const TilingScheme& tiling_scheme, + llvm::Value* thread_id_x, + llvm::Type* index_ty, + llvm::IRBuilder<>* b) { + int64_t multiplier = tiling_scheme.GetIndexingOrder() == kStridedIndexingX + ? tiling_scheme.GetVectorSize() + : tiling_scheme.GetTileSizeFor(kDimX); + return b->CreateMul(thread_id_x, + llvm::ConstantInt::get(index_ty, multiplier)); +} + +static IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, + absl::Span dims_in_elems) { + CHECK_EQ(normalized_shape_index.size(), 3); + // If the normalization only add a new dimensions of size 1, + // generate simpler indexing. LLVM doesn't always simplify the more + // complicated indexing and this prevents it from vectorizing some + // cases. We do this only for major_to_minor memory layout. 
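A plain-integer restatement of the special cases this comment describes and the lines below implement: when the normalization only added a leading size-1 dimension, the unnormalized index is recovered by dropping or swapping components rather than by a general bitcast-style reindexing. This is a simplified sketch that only compares dimension sizes; the real GetUnnormalizedIndex also consults minor_to_major to choose between the two cases, and the function and sample shapes here are illustrative, not part of the codebase.

#include <array>
#include <cstdint>
#include <cstdio>

// Normalized indices in this sketch are always rank 3 with a leading
// dimension of size 1; the general fallback path is omitted.
std::array<int64_t, 2> UnnormalizeIndex(const std::array<int64_t, 3>& index,
                                        const std::array<int64_t, 3>& dims,
                                        const std::array<int64_t, 2>& out_dims) {
  if (out_dims[0] == dims[1] && out_dims[1] == dims[2]) {
    return {index[1], index[2]};  // drop the leading size-1 dimension
  }
  if (out_dims[0] == dims[2] && out_dims[1] == dims[1]) {
    return {index[2], index[1]};  // dropped and transposed
  }
  return {-1, -1};  // general case: handled via bitcast indexing, not shown
}

int main() {
  auto idx = UnnormalizeIndex({0, 5, 7}, {1, 16, 32}, {16, 32});
  std::printf("[%lld, %lld]\n", static_cast<long long>(idx[0]),
              static_cast<long long>(idx[1]));  // [5, 7]
  return 0;
}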
+ if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && + unnormalized_shape.layout().minor_to_major(1) == 0) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, + normalized_shape_index.GetType()); + } + if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && + unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] && + unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] && + unnormalized_shape.layout().minor_to_major(1) == 1) { + CHECK_EQ(normalized_shape_index.dims()[0], 1); + auto multidim = normalized_shape_index.multidim(); + return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape, + normalized_shape_index.GetType()); + } + return normalized_shape_index.SourceIndexOfBitcast( + ShapeUtil::MakeShape(F32, dims_in_elems), unnormalized_shape, b_); +} + +static int GetNumOutputs(const Shape& shape) { + if (shape.IsTuple()) { + return shape.tuple_shapes_size(); + } + return 1; +} + +ReductionCodegenState IrEmitterUnnested::GenerateReductionCodegenState( + mlir::lmhlo::FusionOp fusion, const ReductionCodegenInfo& reduction_info, + absl::Span reduce_instr_index_group, + FusedIrEmitter& fused_emitter) { + ReductionCodegenState reduction_codegen_state(reduction_info); + VLOG(10) << "Emit prologue for reduction: " << llvm_ir::DumpToString(fusion); + + for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { + int num_partial_results = reduction_codegen_state.GetNumPartialResults(); + for (int op_result_idx = 0; + op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { + Shape result_shape = reduce_hlo->shape().IsTuple() + ? reduce_hlo->shape().tuple_shapes(op_result_idx) + : reduce_hlo->shape(); + + llvm::Type* element_type = + llvm_ir::PrimitiveTypeToIrType(result_shape.element_type(), module_); + llvm::AllocaInst* reduction_input_address = + llvm_ir::EmitAllocaAtFunctionEntry(element_type, + "reduction_input_address", &b_); + + llvm::AllocaInst* partial_result_address = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + element_type, /*element_count=*/b_.getInt32(num_partial_results), + "partial_reduction_result", &b_); + + const HloInstruction* init_value = + reduce_hlo->init_values()[op_result_idx]; + + // Initialize the partial result with the initial value of the reduction. + llvm::Value* init_ir_value = (*fused_emitter.GetGenerator(*init_value))( + IrArray::Index(b_.getInt32Ty())) + .value(); + + for (int i = 0; i < num_partial_results; ++i) { + b_.CreateStore(init_ir_value, + InBoundsGEP(partial_result_address->getAllocatedType(), + partial_result_address, {b_.getInt32(i)})); + } + + const TilingScheme& tiling_scheme = + reduction_codegen_state.GetTilingScheme(); + int64_t num_threads_x = tiling_scheme.GetNumThreadsFor(kDimX); + llvm::GlobalVariable* shared_cache = [&]() -> llvm::GlobalVariable* { + if (reduction_codegen_state.IsRowReduction()) { + // Multi-row reductions do not use shared memory. + if (RowReductionGetRowsPerWarp(tiling_scheme.GetDimsInElems()[2]) > + 1) { + return nullptr; + } + // Allocate __shared__ + // cache[num_partial_results][num_warps][scaling_factor]. 
+ CHECK_EQ(tiling_scheme.GetNumThreadsPerBlock() % WarpSize(), 0); + int num_warps = tiling_scheme.GetNumThreadsPerBlock() / WarpSize(); + return AllocateShared(tiling_scheme, element_type, + {num_partial_results, num_warps}, + "shared_cache"); + } else { + // Allocate __shared__ + // cache[num_threads][num_threads + 1], where + // num_threads == num_threads_x == num_threads_y. The "+1" is used to + // avoid bank conflicts. + // + // (Although each thread produces num_partial_results results, we + // don't need that much cache: Only one result is live at a time.) + CHECK_EQ(num_threads_x, tiling_scheme.GetNumThreadsFor(kDimY)); + return AllocateShared(tiling_scheme, element_type, + {num_threads_x, num_threads_x + 1}, + "shared_cache"); + } + }(); + + llvm_ir::ElementGenerator input_gen = + *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); + reduction_codegen_state.SetCalculationStateFor( + {shared_cache, init_ir_value, partial_result_address, + reduction_input_address, input_gen}, + reduce_hlo, op_result_idx); + } + } + + return reduction_codegen_state; +} + +void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp) { + // This only works when the block size is a multiple of 32 threads. + + // We check this here as a mistake in the number of threads per + // block is very hard to detect. + CHECK_EQ(threads_per_block % 32, 0); + CHECK_EQ(WarpSize() % num_results_per_warp, 0); + + for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { + absl::InlinedVector reduction_params; + + for (auto acc : partial_result_addresses) { + reduction_params.push_back(acc.first); + } + + for (auto [partial_result_address, element_type] : + partial_result_addresses) { + int bit_width = llvm_ir::GetSizeInBits(element_type); + llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "result_from_other_lane", &b_); + + reduction_params.push_back(result_from_other_lane); + + // Bitcast cannot be applied to aggregate types (even packed ones), so + // we bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffled_value_type = + element_type->isStructTy() ? 
b_.getIntNTy(bit_width) : element_type; + auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) { + return b_.CreatePointerBitCastOrAddrSpaceCast( + ptr, shuffled_value_type->getPointerTo()); + }; + + llvm::Value* partial_result = + b_.CreateLoad(shuffled_value_type, + convert_pointer_for_shuffle(partial_result_address), + "partial_reduction_result"); + b_.CreateStore( + EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_), + convert_pointer_for_shuffle(result_from_other_lane)); + } + + StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, + *reducer, reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + b_.CreateStore(/*Val=*/returned_scalars->at(i), + /*Ptr=*/partial_result_addresses[i].first); + } + } +} + +llvm::Value* IrEmitterUnnested::GetOutputAddressForReduction( + int partial_result_idx, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const IrEmitterUnnested::ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int output_idx) { + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + const TilingThreadIdInfo& thread_id_info = tiling_kernel_info.thread_id_info; + + IrArray::Index start_offset = [&] { + llvm::Value* x_loc = thread_id_info.thread_id_x; + llvm::Value* y_loc = thread_id_info.thread_id_y; + if (!reduction_codegen_state.IsRowReduction()) { + std::swap(x_loc, y_loc); + } + llvm::Value* start_offset_x = + GetStartOffsetX(tiling_scheme, x_loc, index_ty, &b_); + return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_) + .AddOffsetToDim(start_offset_x, kDimX, &b_); + }(); + + const IrArray& output_array = output_arrays.at(reduction)[output_idx]; + const Shape& operand_shape = reduction->inputs()[output_idx]->shape(); + Shape reduction_kept_element_shape = + ShapeUtil::DeleteDimensions(reduction->dimensions(), operand_shape); + + // Given the IrArray index of a reduction input, returns the linear address of + // the reduction output as if the reduction were going to keep the input shape + // with the dimensions being reduced moved. + llvm::Value* untransposed_output_linear_address = [&] { + const llvm_ir::IrArray::Index index = + start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX, &b_); + if (reduction_codegen_state.IsRowReduction()) { + // For row-reduction, y-coordinate determines which row we write into. + return index[kDimY]; + } + // For column reduction, we get the transposed address. + absl::Span dims_in_elem = tiling_scheme.GetDimsInElems(); + llvm::Value* x_dim_size = + index.GetConstantWithIndexType(dims_in_elem[kDimX]); + llvm::Value* x_block_offset = b_.CreateMul(index[kDimZ], x_dim_size); + return b_.CreateAdd(x_block_offset, index[kDimX]); + }(); + + // A reduction is allowed to transpose its output. For example, suppose + // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are + // allowed to produce as output either f32[10,30]{1,0} (no transpose) or + // f32[10,30]{0,1} (transposing the two output dims). + // + // At this point in the function we have a "partial sum" of input elements + // (stored in partial_result_addresses), and we need to accumulate it into + // the correct output element. 
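The shuffle-down loop in EmitFullWarpShuffleDownLoopForReduce above halves the shuffle distance from 16 down to 1, so lane 0 ends up holding the reduction of its whole warp. The host-side model below replays that schedule on a plain array, with addition standing in for the reducer computation; WarpReduceSumModel is an illustrative name, and the hardware shuffle is only approximated (lanes that would read past the warp use 0 here, which leaves lane 0's result unaffected).

#include <array>
#include <cstdio>

// Host-side replay of the shuffle-down schedule: at each step every lane
// combines its value with the value held `distance` lanes above it, so after
// log2(32) steps lane 0 holds the reduction of the whole warp.
float WarpReduceSumModel(std::array<float, 32> lanes) {
  for (int distance = 16; distance >= 1; distance /= 2) {
    for (int lane = 0; lane < 32; ++lane) {
      const float from_other_lane =
          (lane + distance < 32) ? lanes[lane + distance] : 0.0f;
      lanes[lane] += from_other_lane;  // the "call" to the reducer
    }
  }
  return lanes[0];
}

int main() {
  std::array<float, 32> lanes;
  for (int i = 0; i < 32; ++i) lanes[i] = static_cast<float>(i);
  std::printf("%g\n", WarpReduceSumModel(lanes));  // 496, i.e. 0 + 1 + ... + 31
  return 0;
}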
+ IrArray::Index element_index( + /*linear=*/untransposed_output_linear_address, + reduction_kept_element_shape, &b_); + IrArray::Index output_index(element_index.multidim(), output_array.GetShape(), + element_index.GetType()); + + return output_array.EmitArrayElementAddress(output_index, &b_, + "output_element_address"); +} + llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, llvm::Type* element_type, llvm::Twine name) { @@ -3302,6 +3709,225 @@ llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input, name); } +void IrEmitterUnnested::WriteReductionOutput( + llvm::Type* index_ty, const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx, + const absl::Span values) { + const HloComputation* reducer = reduction->to_apply(); + for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { + auto [output_ptr, type] = typed_ptr; + llvm::Value* output_address = GetOutputAddressForReduction( + partial_result_idx, index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, oidx); + if (reduction_codegen_state.IsRaceFree()) { + b_.CreateStore(b_.CreateLoad(type, output_ptr, "output"), output_address); + } else { + CHECK_EQ(values.size(), 1); + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + &b_, *ir_emitter_context_, *reducer, output_address, output_ptr, + type)); + } + } +} + +void IrEmitterUnnested::EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return b_.CreateICmpEQ(value, constant(0)); + }; + + int num_outputs = reducer->num_parameters() / 2; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + absl::InlinedVector current_outputs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + current_outputs.push_back( + {InBoundsGEP(state.partial_result_address->getAllocatedType(), + state.partial_result_address, + {constant(partial_result_idx)}, "current_output"), + state.partial_result_address->getAllocatedType()}); + } + + int reduced_dimension_size = tiling_scheme.GetDimsInElems()[2]; + int num_rows_per_warp = RowReductionGetRowsPerWarp(reduced_dimension_size); + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(current_outputs), + tiling_scheme.GetNumThreadsPerBlockPhysical(), num_rows_per_warp); + + KernelSupportLibrary ksl(&b_); + llvm::Value* warp_id = + b_.CreateUDiv(thread_id_info.thread_id_x, constant(WarpSize())); + + auto emit_write_output = [&](llvm::Value* write_condition, + const absl::Span values) { + ksl.If("reduction_write_output", write_condition, [&] { + WriteReductionOutput(index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, + partial_result_idx, values); + }); + }; + + if (num_rows_per_warp > 1) { + llvm::Value* is_writing_thread = is_zero(b_.CreateAnd( + 
thread_id_info.thread_id_x, constant(reduced_dimension_size - 1))); + emit_write_output(is_writing_thread, current_outputs); + return; + } + + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* shmem_output_addr = thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, {constant(partial_result_idx), warp_id}); + Store(Load(current_outputs[oidx].second, current_outputs[oidx].first), + shmem_output_addr); + } + }); + + // TODO(cheshire): Don't we want to sync it once for everything in the + // output? Not once per each? + EmitSyncThreads(); + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { + absl::InlinedVector selected_values; + for (int oidx = 0; oidx < num_outputs; oidx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, oidx); + llvm::Value* block_accum_addr = thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, + {constant(partial_result_idx), thread_id_info.lane_id}); + + llvm::Type* element_type = + state.partial_result_address->getAllocatedType(); + + /* Insure initial value address is in generic, not scratch. */ + llvm::Value* initial_value_addr = + CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry( + element_type, "initial_value_addr", &b_), + element_type); + b_.CreateStore(state.initial_value, initial_value_addr); + + llvm::Value* warp_exists = b_.CreateICmpULT( + thread_id_info.thread_id_x, + constant(tiling_scheme.GetNumThreadsFor(kDimX) / WarpSize())); + + llvm::Value* selected_value = + b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr); + + selected_values.push_back({selected_value, element_type}); + } + + // If only one warp is present in the block, then we don't need inter-warp + // reduction. + // TODO(b/241414088) If only warp is present, then inter-warp communication + // using shared memory and synchronization using barrier is also unnecessary + // and should be removed. + if (tiling_scheme.GetNumThreadsPerBlock() > WarpSize()) { + EmitFullWarpShuffleDownLoopForReduce( + reducer, absl::MakeSpan(selected_values), + tiling_scheme.GetNumThreadsPerBlock()); + } + + emit_write_output(is_zero(thread_id_info.thread_id_x), selected_values); + }); +} + +void IrEmitterUnnested::EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, llvm::Type* index_ty, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx) { + KernelSupportLibrary ksl(&b_); + const HloComputation* reducer = reduction->to_apply(); + const auto& thread_id_info = tiling_kernel_info.thread_id_info; + + auto constant = [&](uint64_t c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + auto is_zero = [&](llvm::Value* value) { + return b_.CreateICmpEQ(value, constant(0)); + }; + const TilingScheme& tiling_scheme = reduction_codegen_state.GetTilingScheme(); + int num_outputs = reducer->num_parameters() / 2; + + // Wait for reads from shmem in the last iteration to complete. (If this is + // slow, we could "double-buffer" by having two shmem buffers and switching + // between them.) + if (partial_result_idx > 0) { + EmitSyncThreads(); + } + + // Store the transpose in shared memory. 
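The lines that follow store each thread's partial result transposed into the shared cache and, after a barrier, read it back transposed, turning the column reduction into a row reduction that the warp shuffle can finish. The host model below replays that round trip with a row stride of T + 1, mirroring the "+1" padding used on the GPU to avoid shared-memory bank conflicts (the padding has no observable effect on a CPU); the tile size and values are illustrative only.

#include <cstdio>
#include <vector>

// Host-side model: "thread" (x, y) stores its partial result at cache[x][y]
// and, after the barrier, reads cache[y][x].
int main() {
  constexpr int kT = 4;  // threads per tile side (illustrative)
  std::vector<float> cache(kT * (kT + 1), 0.0f);
  auto at = [&](int i, int j) -> float& { return cache[i * (kT + 1) + j]; };

  // "Store the transpose in shared memory": thread (x, y) writes [x][y].
  for (int y = 0; y < kT; ++y)
    for (int x = 0; x < kT; ++x) at(x, y) = static_cast<float>(10 * y + x);

  // __syncthreads() would separate the two phases on the GPU.

  // "Get transposed element from shared memory": thread (x, y) reads [y][x],
  // so each row of reads now walks what used to be a column of results.
  for (int y = 0; y < kT; ++y) {
    for (int x = 0; x < kT; ++x) std::printf("%5.1f", at(y, x));
    std::printf("\n");
  }
  return 0;
}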
+ for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::GlobalVariable* shared_cache = state.shared_cache; + llvm::AddrSpaceCastInst* shmem_output_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + &b_, shared_cache, + {thread_id_info.thread_id_x, thread_id_info.thread_id_y}, + "shmem_output_address")); + llvm::Value* current_output = + InBoundsGEP(state.partial_result_address->getAllocatedType(), + state.partial_result_address, + {constant(partial_result_idx)}, "current_output"); + + llvm::Value* current_output_value = + Load(state.partial_result_address->getAllocatedType(), current_output); + b_.CreateStore(current_output_value, shmem_output_addr); + } + + EmitSyncThreads(); + + // Get transposed element from shared memory. + absl::InlinedVector shmem_transposed_addrs; + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + reduction_codegen_state.GetCalculationStateFor(reduction, output_idx); + llvm::AddrSpaceCastInst* shmem_transposed_addr = + llvm::cast(thread_id_info.GEPIntoSharedMemory( + &b_, state.shared_cache, + {thread_id_info.thread_id_y, thread_id_info.thread_id_x}, + "shmem_transposed_addr")); + shmem_transposed_addrs.push_back( + {shmem_transposed_addr, llvm::cast( + shmem_transposed_addr->getPointerOperand()) + ->getResultElementType()}); + } + + EmitFullWarpShuffleDownLoopForReduce(reducer, + absl::MakeSpan(shmem_transposed_addrs), + tiling_scheme.GetNumThreadsPerBlock()); + + // Some warps in the block are completely outside of the bound of the + // tensor, so they should not write any output at all. + llvm::Value* has_output = + b_.CreateAnd(b_.CreateICmpULT(GetStartOffsetX(tiling_scheme, + thread_id_info.thread_id_y, + index_ty, &b_), + tiling_kernel_info.output_tile_bounds[1]), + b_.CreateICmpULT(thread_id_info.thread_id_x, + tiling_kernel_info.output_tile_bounds[0])); + + ksl.If("reduction_write_output", + b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + WriteReductionOutput(index_ty, reduction_codegen_state, + tiling_kernel_info, output_arrays, reduction, + partial_result_idx, shmem_transposed_addrs); + }); +} + llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() { MaybeEmitFenceForAMDGPU(llvm::AtomicOrdering::SequentiallyConsistent, "workgroup"); @@ -3481,6 +4107,239 @@ llvm::GlobalVariable* IrEmitterUnnested::AllocateShared( array_type, buffer_name); } +// Generate a single element of the tile (update the accumulator state) for a +// given reducer of index `i`. 
+void IrEmitterUnnested::GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays) { + HloComputation* reducer = reduction->to_apply(); + CHECK_EQ(reducer->num_parameters() % 2, 0); + + absl::InlinedVector reduction_accumulators; + absl::InlinedVector reduction_input_value; + for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { + const ReductionCodegenState::ReductionCalculationState& state = + codegen_state.GetCalculationStateFor(reduction, red_idx); + + llvm::AllocaInst* input_address = state.input_address; + llvm::AllocaInst* partial_reduction_result_address = + state.partial_result_address; + llvm::Value* const input_ir_value = *state.input_gen( + num_partial_results > 1 ? index_without_linear : input_index); + b_.CreateStore(input_ir_value, input_address); + llvm::Value* partial_result_address = + InBoundsGEP(partial_reduction_result_address->getAllocatedType(), + partial_reduction_result_address, {partial_result_index}); + reduction_accumulators.push_back(partial_result_address); + reduction_input_value.push_back(input_address); + } + + absl::InlinedVector reduction_params; + for (llvm::Value* acc : reduction_accumulators) { + reduction_params.push_back(acc); + } + for (llvm::Value* value : reduction_input_value) { + reduction_params.push_back(value); + } + + // Emit a call to the variadic reducer. Since it may be returning a + // tuple, we can't return it directly as a value. Instead, before + // the call, we create N (N = # arguments in the tuple) allocas, one + // for each returned argument, then when we make the call we pass N + // pointers as last parameters, the called computation writes into + // those pointers, and we have returned values on the stack (as well + // as pointers to them). 
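The calling convention described in the comment above can be mimicked in ordinary C++: every accumulator and input is passed by address, and the "tuple" result comes back through trailing out-pointers that the caller then copies into its accumulators. ReducerMinWithIndex below is a made-up two-result reducer used only to show the shape of the call; it is not the emitted computation.

#include <cstdio>

// A made-up two-result reducer (value plus index, as in an argmin-style
// reduction). Every operand is passed by address and the "tuple" result is
// written through the trailing out-pointers.
void ReducerMinWithIndex(const float* acc_val, const int* acc_idx,
                         const float* in_val, const int* in_idx,
                         float* out_val, int* out_idx) {
  if (*in_val < *acc_val) {
    *out_val = *in_val;
    *out_idx = *in_idx;
  } else {
    *out_val = *acc_val;
    *out_idx = *acc_idx;
  }
}

int main() {
  // Stack scalars stand in for the alloca'd accumulator and input addresses.
  float acc_val = 7.0f; int acc_idx = 3;
  float in_val = 2.5f;  int in_idx = 11;

  // Scratch slots for the returned values, copied back into the accumulators
  // after the call, just as the emitter does with returned_scalars.
  float out_val; int out_idx;
  ReducerMinWithIndex(&acc_val, &acc_idx, &in_val, &in_idx, &out_val, &out_idx);
  acc_val = out_val;
  acc_idx = out_idx;

  std::printf("min = %.1f at index %d\n", acc_val, acc_idx);  // min = 2.5 at index 11
  return 0;
}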
+ StatusOr> returned_scalars = + CallNestedComputationWithScalarAddrs(&b_, *ir_emitter_context_, *reducer, + reduction_params); + TF_CHECK_OK(returned_scalars.status()); + + for (int i = 0; i < returned_scalars->size(); i++) { + b_.CreateStore(returned_scalars->at(i), reduction_accumulators[i]); + } +} + +Status IrEmitterUnnested::EmitIRForReduction( + mlir::lmhlo::FusionOp fusion, + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenInfo& reduction_info, const Shape& input_shape) { + std::vector reductions; + ExtraOutputGensMap extra_output_gens; + + for (const HloInstruction* hlo : instr_index_group) { + if (IsReductionFromOrToContiguousDimensions(*hlo)) { + reductions.push_back(Cast(hlo)); + } else { + extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); + } + } + + CHECK(!reductions.empty()) << " expect at least one reduce instructions."; + const TilingScheme& tiling_scheme = reduction_info.GetTilingScheme(); + CHECK_EQ(tiling_scheme.GetNumThreadsPerBlockPhysical() % WarpSize(), 0); + llvm::Type* index_ty = + GetIndexTypeForKernel(fusion, + tiling_scheme.GetNumThreadsPerBlockPhysical() * + tiling_scheme.GetNumberOfBlocksPhysical(), + &b_); + ReductionCodegenState codegen_state = GenerateReductionCodegenState( + fusion, reduction_info, reductions, fused_emitter); + + EmitTileElementFunction emit_reduction_element = + [&](const TilingThreadIdInfo& thread_id_info, const IrArray::Index& index, + llvm::Value* y_loc, llvm::Value* x_loc) { + IrArray::Index input_index = GetUnnormalizedIndex( + index, input_shape, &b_, + codegen_state.GetTilingScheme().GetDimsInElems()); + llvm::Value* partial_result_index = + codegen_state.IsRowReduction() + ? b_.getInt32(0) + : b_.CreateSub( + x_loc, + GetStartOffsetX(tiling_scheme, thread_id_info.thread_id_x, + index_ty, &b_)); + + // Clear the linear index field of the IrArray::Index to enable the use + // of GetElementPointer with array types. This enables the vectorization + // of the computation for different partial results. Use this index if + // 'num_partial_results > 1'. + int num_partial_results = codegen_state.GetNumPartialResults(); + IrArray::Index index_without_linear{input_index.multidim(), input_shape, + input_index.GetType()}; + + // Emit code to generate the input and perform the reduction computation + // for each reduction instruction. + for (const HloReduceInstruction* reduce : reductions) { + GenerateElementForReducer(reduce, partial_result_index, codegen_state, + index_without_linear, input_index, + num_partial_results, result_ir_arrays); + } + + // Emit code to generate the output for the non-reduction instructions + // in the fusion, if any. 
+ TF_CHECK_OK(EmitExtraOutputsForReduce(input_shape, result_ir_arrays, + input_index, reduction_info, + extra_output_gens)); + }; + + TF_ASSIGN_OR_RETURN( + TilingKernelInfo tiling_kernel_info, + EmitTilingKernel( + &b_, tiling_scheme, index_ty, + [&](const TilingThreadIdInfo& thread_id_info, + const IrArray::Index& index, ValueVector2 tile_dimensions) { + EmitTile(&b_, codegen_state.GetTilingScheme(), index, + thread_id_info, tile_dimensions, emit_reduction_element); + })); + + KernelSupportLibrary ksl(&b_); + for (const HloReduceInstruction* reduce : reductions) { + for (int partial_result_idx = 0; + partial_result_idx < reduction_info.GetNumPartialResults(); + ++partial_result_idx) { + if (codegen_state.IsRowReduction()) { + EmitReductionOutputForRowReduction(tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, reduce, + partial_result_idx); + } else { + EmitReductionOutputForColumnReduction(tiling_kernel_info, codegen_state, + index_ty, result_ir_arrays, + reduce, partial_result_idx); + } + } + } + + return OkStatus(); +} + +Status IrEmitterUnnested::EmitUnnestedReduction( + mlir::lmhlo::FusionOp fusion, HloFusionAnalysis& fusion_analysis) { + auto* reduction_codegen_info = fusion_analysis.GetReductionCodegenInfo(); + // Set flag to false as Reduction has it's own custom logic of choosing a + // block size. + TF_ASSIGN_OR_RETURN(auto launch_dimensions, + fusion_analysis.GetLaunchDimensions( + /*use_experimental_block_size=*/false)); + + VLOG(3) << "Launch dimensions of " + << mlir::mhlo::GetDebugNameFromLocation(fusion.getLoc()) << ": " + << launch_dimensions.ToString(); + if (!reduction_codegen_info->IsRaceFree()) { + absl::Span fusion_roots = + fusion_analysis.fusion_roots(); + for (int i = 0; i < fusion_roots.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*fusion_roots[i])) { + TF_RETURN_IF_ERROR(BuildFusedInitializerThunk(fusion, i)); + } + } + } + + TF_ASSIGN_OR_RETURN( + std::optional> opt_ir_arrays, + BuildKernelThunkForFusion(fusion, launch_dimensions)); + if (!opt_ir_arrays.has_value()) { + // The kernel was reused, no need to emit code. + return OkStatus(); + } + std::vector& ir_arrays = opt_ir_arrays.value(); + + FusedIrEmitter fused_emitter(elemental_emitter_); + const HloComputation* fused_computation = fusion_analysis.fused_computation(); + CHECK_LT(fused_computation->num_parameters(), ir_arrays.size()); + for (int i = 0; i < fused_computation->num_parameters(); i++) { + llvm_ir::IrArray ir_array = ir_arrays[i]; + HloInstruction* fused_operand = fused_computation->parameter_instruction(i); + fused_emitter.BindGenerator( + *fused_operand, + [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) { + return ir_array.EmitReadArrayElement(index, &b_, + fused_operand->name()); + }); + } + + // Get outputs. + ReductionOutputMap result_ir_arrays; + + // Skip all parameter buffers first. + int ir_arrays_idx = fused_computation->num_parameters(); + for (HloInstruction* root : fusion_analysis.fusion_roots()) { + int get_num_results = GetNumOutputs(root->shape()); + result_ir_arrays[root] = + absl::MakeSpan(ir_arrays).subspan(ir_arrays_idx, get_num_results); + ir_arrays_idx += get_num_results; + } + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. 
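A host-side picture of the dispatch this sets up: the kernel is launched with one block row per independent reduction group, and each block executes only the group whose index matches its block_id_y. In the sketch below the loop over block_id_y stands in for the grid and the callables stand in for the per-group reduction bodies; all names and the group count are illustrative.

#include <cstdio>
#include <functional>
#include <vector>

// One entry per independent reduction group; each "block" runs only the group
// selected by its block_id_y.
int main() {
  const std::vector<std::function<void()>> groups = {
      [] { std::printf("reduce-group-0\n"); },
      [] { std::printf("reduce-group-1\n"); },
      [] { std::printf("reduce-group-2\n"); },
  };
  const int num_groups = static_cast<int>(groups.size());
  for (int block_id_y = 0; block_id_y < num_groups; ++block_id_y) {
    // The emitted IR is a chain of "reduce-group-i" conditionals comparing the
    // raw block_id_y against each group index.
    for (int i = 0; i < num_groups; ++i) {
      if (block_id_y == i) groups[i]();
    }
  }
  return 0;
}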
+ const std::vector>& instr_index_groups = + reduction_codegen_info->GetIndexGroups(); + Shape reduce_operand_shape = reduction_codegen_info->GetReduceOperandShape(); + + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), + llvm::cast(raw_block_id_y)); + for (int i = 0; i < instr_index_groups.size(); ++i) { + TF_RETURN_IF_ERROR(ksl.IfWithStatus( + StrCat("reduce-group-", i), + b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] { + return EmitIRForReduction( + fusion, instr_index_groups[i], fused_emitter, result_ir_arrays, + *reduction_codegen_info, reduce_operand_shape); + })); + } + + return OkStatus(); +} + // Emits code for slices based on the below structure. An if statement with // a guarding condition is generated for each ROOT slice. // diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 905ed163d64ead..a55276d9dcdb4f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -324,6 +324,75 @@ class IrEmitterUnnested : public IrEmitter { const ReductionCodegenInfo& reduction_info, const ExtraOutputGensMap& extra_output_gens); + // Generates code for reduction to contiguous dimensions. + // + // Row reduction uses the following algorithm described in CUDA-like + // pseudocode: + // + // ``` + // __global__ void reduce(int num_rows, float *in, float out) { + // __shared__ float[32] cache; + // int offset = blockDim.x * blockIdx.x + threadIdx.x; + // if (offset >= num_rows) return; + // int tile_bound = std::min(offset + kTileSizeX, num_rows); + // float accum = 0; + // for (int i=offset; i reduce_instr_index_group, + FusedIrEmitter& fused_emitter); + + // Wraps up the code generation for a tile block of a reduction kernel: + // write the calculated output into the output tensor. + void EmitReductionOutput( + llvm::Type* index_ty, mlir::lmhlo::FusionOp fusion, + absl::Span reduce_instr_index_group, + const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info); + + // Returns the address to write the reduction output to. + llvm::Value* GetOutputAddressForReduction( + int partial_result_idx, llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int output_idx); + + // Performs the actual write of the reduction result. + using TypedPointer = std::pair; + void WriteReductionOutput( + llvm::Type* index_ty, + const ReductionCodegenState& reduction_codegen_state, + const TilingKernelInfo& tiling_kernel_info, + const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx, + const absl::Span values); + + // `current_output`: the value the tile has calculated. + // `output_address`: address where the output value has to be written. + void EmitReductionOutputForRowReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + llvm::Type* index_ty, const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx); + + // Same arguments as EmitReductionOutputForRowReduction. 
+ void EmitReductionOutputForColumnReduction( + const TilingKernelInfo& tiling_kernel_info, + const ReductionCodegenState& reduction_codegen_state, + llvm::Type* index_ty, const ReductionOutputMap& output_arrays, + const HloReduceInstruction* reduction, int partial_result_idx); + + // Emits code for reductions in the output_instructions. + Status EmitIRForReduction(mlir::lmhlo::FusionOp fusion, + absl::Span instr_index_group, + FusedIrEmitter& fused_emitter, + const ReductionOutputMap& result_ir_arrays, + const ReductionCodegenInfo& reduction_info, + const Shape& input_shape); + + // Generate a single element of the tile (update the accumulator state) for a + // given reducer of index `i`. + void GenerateElementForReducer( + const HloReduceInstruction* reduction, llvm::Value* partial_result_index, + const ReductionCodegenState& codegen_state, + const llvm_ir::IrArray::Index& index_without_linear, + const llvm_ir::IrArray::Index& input_index, int num_partial_results, + const ReductionOutputMap& result_ir_arrays); + + // Emits shuffle-down reduction for the `partial_result_address` using the + // reduction computation `reducer`, writes output into + // `partial_result_address`. + // + // Multiple partial_result_address inputs happen when doing variadic + // reduction: each one should get the output value. + void EmitFullWarpShuffleDownLoopForReduce( + const HloComputation* reducer, + absl::Span partial_result_addresses, + int threads_per_block, int num_results_per_warp = 1); + // Allocates a shared tile of given dimensions, applying scaling specified in // tilng_scheme as a major-most dimension to avoid collisions. llvm::GlobalVariable* AllocateShared( @@ -472,8 +618,20 @@ class IrEmitterUnnested : public IrEmitter { mlir::Operation* op, mlir::ValueRange needed_operands, const LaunchDimensions& launch_dimensions); + // Returns a thunk that, given a reduce or select-and-scatter op, + // initializes its memory to the appropriate initial value. + std::unique_ptr BuildConstantInitializerThunk( + mlir::Operation* op, absl::Span init_value, + mlir::Value dest, const BufferAllocation::Slice& dest_slice, + const Shape& output_shape); + + StatusOr> TryBuildConstantInitializerThunk( + mlir::Operation* op, mlir::Value init_value, mlir::Value dest); + Status BuildInitializerThunk(mlir::Operation* op, mlir::Value init_value, mlir::Value dest); + Status BuildFusedInitializerThunk(mlir::lmhlo::FusionOp fusion, + int output_index); // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. From 9c00e70a24a4c35a2afe0279ea97a6e823585efd Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Tue, 25 Jul 2023 11:59:51 -0700 Subject: [PATCH 126/410] Fix typo in wheel verification selection script PiperOrigin-RevId: 550955286 --- ci/official/utilities/rename_and_verify_wheels.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 84ad5bf9a99682..502ce8ac4488a9 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -20,7 +20,7 @@ set -euxo pipefail cd $1 -for wheel in build/*.whl; do +for wheel in *.whl; do echo "Checking and renaming $wheel..." 
time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt From df4279c26f37432e76baf4013d68af047f74559f Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 25 Jul 2023 12:14:19 -0700 Subject: [PATCH 127/410] Integrate StableHLO at openxla/stablehlo@8816d058 PiperOrigin-RevId: 550959571 --- third_party/stablehlo/temporary.patch | 106 +++----------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 14 insertions(+), 96 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index b697e5c80707ca..da15cb22c58f06 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,35 +1,10 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel --- stablehlo/BUILD.bazel +++ stablehlo/BUILD.bazel -@@ -227,20 +227,6 @@ - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/integrations/python/mlir/dialects/ChloOps.td", - deps = [ -- ":chlo_ops_py_td_files", -- ], --) -- --td_library( -- name = "chlo_ops_py_td_files", -- srcs = [ -- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", -- ], -- includes = [ -- ".", -- "include", -- ], -- deps = [ - ":chlo_ops_td_files", - "@llvm-project//mlir:OpBaseTdFiles", - ], -@@ -289,6 +275,24 @@ - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:TransformUtils", -+ ], -+) -+ -+cc_library( +@@ -279,6 +279,24 @@ + ) + + cc_library( + name = "experimental_ops", + srcs = [ + "stablehlo/dialect/ExperimentalOps.cpp", @@ -44,28 +19,14 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", - ], - ) - -@@ -635,17 +639,6 @@ - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/integrations/python/mlir/dialects/StablehloOps.td", - deps = [ -- ":stablehlo_ops_py_td_files", -- ], --) -- --td_library( -- name = "stablehlo_ops_py_td_files", -- srcs = [ -- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", -- ], -- includes = ["."], -- deps = [ - ":stablehlo_ops_td_files", - "@llvm-project//mlir:OpBaseTdFiles", - ], -@@ -702,6 +695,7 @@ ++ ], ++) ++ ++cc_library( + name = "reference_axes", + srcs = [ + "stablehlo/reference/Axes.cpp", +@@ -677,6 +695,7 @@ deps = [ ":base", ":chlo_ops", @@ -73,27 +34,6 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ":stablehlo_ops", ":stablehlo_ops_inc_gen", ":stablehlo_pass_inc_gen", -@@ -1014,20 +1008,6 @@ - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/integrations/python/mlir/dialects/VhloOps.td", - deps = [ -- ":vhlo_ops_py_td_files", -- ], --) -- --td_library( -- name = "vhlo_ops_py_td_files", -- srcs = [ -- "@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td", -- ], -- includes = [ -- ".", -- "include", -- ], -- deps = [ - ":vhlo_ops_td_files", - "@llvm-project//mlir:OpBaseTdFiles", - ], diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -879,28 +819,6 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +} // namespace mlir + +#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -diff --ruN a/stablehlo/stablehlo/integrations/python/mlir/dialects/ChloOps.td b/stablehlo/stablehlo/integrations/python/mlir/dialects/ChloOps.td ---- stablehlo/stablehlo/integrations/python/mlir/dialects/ChloOps.td -+++ 
stablehlo/stablehlo/integrations/python/mlir/dialects/ChloOps.td -@@ -17,7 +17,6 @@ - #ifndef STABLEHLO_INTEGRATIONS_PYTHON_CHLO_OPS - #define STABLEHLO_INTEGRATIONS_PYTHON_CHLO_OPS - --include "mlir/Bindings/Python/Attributes.td" - include "stablehlo/dialect/ChloOps.td" - - #endif -diff --ruN a/stablehlo/stablehlo/integrations/python/mlir/dialects/StablehloOps.td b/stablehlo/stablehlo/integrations/python/mlir/dialects/StablehloOps.td ---- stablehlo/stablehlo/integrations/python/mlir/dialects/StablehloOps.td -+++ stablehlo/stablehlo/integrations/python/mlir/dialects/StablehloOps.td -@@ -17,7 +17,6 @@ - #ifndef STABLEHLO_INTEGRATIONS_PYTHON_STABLEHLO_OPS - #define STABLEHLO_INTEGRATIONS_PYTHON_STABLEHLO_OPS - --include "mlir/Bindings/Python/Attributes.td" - include "stablehlo/dialect/StablehloOps.td" - - #endif diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 2ca73b67e5523f..deb356c5826bac 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "f4c43e121aa94a27eb1b846f2e1d776732203400" - STABLEHLO_SHA256 = "9dacfeb1bdb8468b8736995cffc5db42de04b6d1b2f0572751aa41f10232d9a8" + STABLEHLO_COMMIT = "8816d0581d9a5fb7d212affef858e991a349ad6b" + STABLEHLO_SHA256 = "4aefbd629468416d731863ae4ae94a7720023e02ca73410690c10b66de4a0990" # LINT.ThenChange(Google-internal path) tf_http_archive( From 7f16035462bb93b58e598974a5b542f09cbd00a2 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 25 Jul 2023 19:22:24 +0000 Subject: [PATCH 128/410] Add TF_PYTHON_VERSION env var to set Python version in macOS arm64 CI --- .../osx/arm64/tensorflow_as_build_nightly.Jenkinsfile | 8 ++++---- .../osx/arm64/tensorflow_as_build_release.Jenkinsfile | 8 ++++---- .../osx/arm64/tensorflow_as_test_ci.Jenkinsfile | 5 ++--- .../osx/arm64/tensorflow_as_test_nightly.Jenkinsfile | 11 ++++------- .../osx/arm64/tensorflow_as_test_release.Jenkinsfile | 11 ++++------- 5 files changed, 18 insertions(+), 25 deletions(-) diff --git a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_nightly.Jenkinsfile b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_nightly.Jenkinsfile index 65256e37b366dc..d565f227684f75 100644 --- a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_nightly.Jenkinsfile +++ b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_nightly.Jenkinsfile @@ -26,6 +26,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.9 } steps { dir('tensorflow') { @@ -52,7 +53,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel --bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.9.13/lib/python3.9/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -84,6 +84,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.10 } steps { dir('tensorflow') { @@ -109,7 +110,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel 
--bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.10.4/lib/python3.10/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -140,6 +140,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.11 } steps { @@ -166,7 +167,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel --bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.11.2/lib/python3.11/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -198,4 +198,4 @@ pipeline { build 'upload-nightly' } } -} \ No newline at end of file +} diff --git a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_release.Jenkinsfile b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_release.Jenkinsfile index 5b99e06bef3514..0cd2cfeb295ff5 100644 --- a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_release.Jenkinsfile +++ b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_build_release.Jenkinsfile @@ -29,6 +29,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.9 } steps { dir('tensorflow') { @@ -50,7 +51,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel --bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.9.13/lib/python3.9/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -78,6 +78,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.10 } steps { dir('tensorflow') { @@ -99,7 +100,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel --bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.10.4/lib/python3.10/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -127,6 +127,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.11 } steps { dir('tensorflow') { @@ -148,7 +149,6 @@ pipeline { sh ''' /opt/homebrew/bin/bazel --bazelrc="${WORKSPACE}/tensorflow/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" build \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.11.2/lib/python3.11/site-packages" \ //tensorflow/tools/pip_package:build_pip_package ./bazel-bin/tensorflow/tools/pip_package/build_pip_package \ @@ -172,4 +172,4 @@ pipeline { } } } -} \ No newline at end of file +} diff --git a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_ci.Jenkinsfile b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_ci.Jenkinsfile index 4d1f67bc8711b8..04b254dff6ccb6 100644 --- a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_ci.Jenkinsfile +++ b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_ci.Jenkinsfile @@ -26,6 +26,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.11 } steps { sh ''' @@ -47,8 +48,6 @@ pipeline { sh ''' bazel 
--bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.11.2/lib/python3.11/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.11.2/bin/python3.11" \ --config=nonpip ''' @@ -57,4 +56,4 @@ pipeline { } } } -} \ No newline at end of file +} diff --git a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_nightly.Jenkinsfile b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_nightly.Jenkinsfile index 2f66102a901b48..749eaf58475823 100644 --- a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_nightly.Jenkinsfile +++ b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_nightly.Jenkinsfile @@ -26,6 +26,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.9 } steps { @@ -47,8 +48,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.9.13/lib/python3.9/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.9.13/bin/python3.9" \ --config=nonpip ''' } @@ -60,6 +59,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.10 } steps { sh ''' @@ -81,8 +81,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.10.4/lib/python3.10/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.10.4/bin/python3.10" \ --config=nonpip ''' @@ -95,6 +93,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.11 } steps { sh ''' @@ -116,8 +115,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.11.2/lib/python3.11/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.11.2/bin/python3.11" \ --config=nonpip ''' @@ -126,4 +123,4 @@ pipeline { } } } -} \ No newline at end of file +} diff --git a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_release.Jenkinsfile b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_release.Jenkinsfile index 7518d329c2edbd..565630bd2443c7 100644 --- a/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_release.Jenkinsfile +++ b/tensorflow/tools/ci_build/osx/arm64/tensorflow_as_test_release.Jenkinsfile @@ -29,6 +29,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.9 } steps { @@ -50,8 +51,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.9.13/lib/python3.9/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.9.13/bin/python3.9" \ --config=nonpip ''' } @@ -63,6 +62,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.10 } steps { sh ''' @@ -83,8 +83,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.10.4/lib/python3.10/site-packages" \ - --action_env 
PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.10.4/bin/python3.10" \ --config=nonpip ''' @@ -97,6 +95,7 @@ pipeline { environment { PYENV_ROOT="$HOME/.pyenv" PATH="$PYENV_ROOT/shims:/opt/homebrew/bin/:$PATH" + TF_PYTHON_VERSION=3.11 } steps { sh ''' @@ -117,8 +116,6 @@ pipeline { sh ''' bazel --bazelrc="${WORKSPACE}/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc" test \ - --action_env PYTHON_LIB_PATH="/Users/admin/.pyenv/versions/3.11.2/lib/python3.11/site-packages" \ - --action_env PYTHON_BIN_PATH="/Users/admin/.pyenv/versions/3.11.2/bin/python3.11" \ --config=nonpip ''' @@ -127,4 +124,4 @@ pipeline { } } } -} \ No newline at end of file +} From 95dd4033b472dc6949410545077e71db96ff9a49 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 12:30:55 -0700 Subject: [PATCH 129/410] Add a new `GetOutputMemoryKinds` method to `PjRtExecutable`, and propagate the info through IFRT. Finally, it can be used in Jax frontend (similar to `GetParameterShardings` and `GetOutputShardings`). PiperOrigin-RevId: 550963792 --- tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h | 10 ++++++++++ tensorflow/compiler/xla/pjrt/pjrt_executable.h | 7 +++++++ .../compiler/xla/pjrt/pjrt_stream_executor_client.cc | 5 +++++ .../compiler/xla/pjrt/pjrt_stream_executor_client.h | 3 +++ tensorflow/compiler/xla/pjrt/tf_pjrt_client.h | 4 ++++ tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h | 5 +++++ tensorflow/compiler/xla/python/ifrt/executable.h | 8 ++++++-- tensorflow/compiler/xla/python/ifrt/mock.h | 2 ++ .../compiler/xla/python/pjrt_ifrt/pjrt_executable.h | 6 ++++++ tensorflow/compiler/xla/python/py_executable.cc | 5 +++++ tensorflow/compiler/xla/python/py_executable.h | 5 +++-- tensorflow/compiler/xla/python/xla.cc | 4 ++++ tensorflow/compiler/xla/python/xla_client.py | 2 +- .../compiler/xla/python/xla_extension/__init__.pyi | 2 ++ 14 files changed, 63 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h index a0417f124f3f80..a7138dd0034b45 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h @@ -384,6 +384,11 @@ class PjRtCApiExecutable : public PjRtExecutable { StatusOr>> GetHloModules() const override; + StatusOr>> GetOutputMemoryKinds() + const override { + return Unimplemented("PJRT C API does not support GetOutputMemoryKinds"); + } + const PJRT_Api* pjrt_c_api() const { return c_api_; } PJRT_Executable* c_executable() const { return executable_.get(); } @@ -432,6 +437,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable { return executable_->GetHloModules(); } + StatusOr>> GetOutputMemoryKinds() + const override { + return executable_->GetOutputMemoryKinds(); + } + StatusOr>>> Execute( absl::Span> argument_handles, const ExecuteOptions& options, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.h b/tensorflow/compiler/xla/pjrt/pjrt_executable.h index bf9fb9c3f2665c..3dc2be6a57d858 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.h @@ -275,6 +275,13 @@ class PjRtExecutable { // `GetHloModules()`. virtual StatusOr> GetOutputShapes() const; + // Returns a list of lists of memory kind strings for output. The returned + // value is `[num_programs, num_output]`. The size of the outer list should be + // equal to `GetHloModules()`. Under SPMD, one can use + // `GetOutputMemoryKinds().front()`. 
+ virtual StatusOr>> + GetOutputMemoryKinds() const = 0; + // Returns a list of parameter OpSharding protos. virtual std::optional> GetParameterShardings() const; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 8880d2f6b94c8d..3988575f923149 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -2843,6 +2843,11 @@ PjRtStreamExecutorExecutable::GetHloModules() const { return std::move(modules); } +StatusOr>> +PjRtStreamExecutorExecutable::GetOutputMemoryKinds() const { + return Unimplemented("GetOutputMemoryKinds is not supported."); +} + StatusOr PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) { ExecutableExtras extras; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index 09ebfa10430079..185120813ab5ba 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -812,6 +812,9 @@ class PjRtStreamExecutorExecutable : public PjRtLoadedExecutable { StatusOr>> GetHloModules() const override; + StatusOr>> GetOutputMemoryKinds() + const override; + using PjRtLoadedExecutable::Execute; StatusOr>>> Execute( absl::Span> argument_handles, diff --git a/tensorflow/compiler/xla/pjrt/tf_pjrt_client.h b/tensorflow/compiler/xla/pjrt/tf_pjrt_client.h index 215d84e0054fe3..8c7303287f12a0 100644 --- a/tensorflow/compiler/xla/pjrt/tf_pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/tf_pjrt_client.h @@ -130,6 +130,10 @@ class TfPjRtExecutable : public PjRtLoadedExecutable { const override { return wrapped_->GetHloModules(); } + StatusOr>> GetOutputMemoryKinds() + const override { + return wrapped_->GetOutputMemoryKinds(); + } using PjRtLoadedExecutable::Execute; StatusOr>>> Execute( absl::Span> argument_handles, diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h index e15e5db1968603..bf23a5bfb02f87 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -379,6 +379,11 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { cpu_executable_->shared_module()}; } + StatusOr>> GetOutputMemoryKinds() + const override { + return Unimplemented("GetOutputMemoryKinds is not supported."); + } + StatusOr GetCompiledMemoryStats() const override { CompiledMemoryStats memory_stats = CompiledMemoryStats(); memory_stats.generated_code_size_in_bytes = SizeOfGeneratedCodeInBytes(); diff --git a/tensorflow/compiler/xla/python/ifrt/executable.h b/tensorflow/compiler/xla/python/ifrt/executable.h index cb73311133a8ec..f2b2dd44d730c8 100644 --- a/tensorflow/compiler/xla/python/ifrt/executable.h +++ b/tensorflow/compiler/xla/python/ifrt/executable.h @@ -19,8 +19,6 @@ limitations under the License. #include #include #include -#include -#include #include #include "absl/types/span.h" @@ -119,6 +117,12 @@ class LoadedExecutable // Return an HloModule (optimized) per partition. virtual StatusOr>> GetHloModules() const = 0; + // Returns a list of lists of memory kind strings for output. The returned + // value is `[num_programs, num_output]`. The size of the outer list should be + // equal to `GetHloModules()`. Under SPMD, one can use + // `GetOutputMemoryKinds().front()`. 
+ virtual StatusOr>> + GetOutputMemoryKinds() const = 0; // Returns named values for cost properties of this executable (such as // operations, size of input/outputs, and run time estimate). Properties may diff --git a/tensorflow/compiler/xla/python/ifrt/mock.h b/tensorflow/compiler/xla/python/ifrt/mock.h index 38c845c21ec4da..3cb77ff4275257 100644 --- a/tensorflow/compiler/xla/python/ifrt/mock.h +++ b/tensorflow/compiler/xla/python/ifrt/mock.h @@ -229,6 +229,8 @@ class MockLoadedExecutable final (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); + MOCK_METHOD(absl::StatusOr>>, + GetOutputMemoryKinds, (), (const, final)); MOCK_METHOD(StatusOr>>, GetHloModules, (), (const, final)); MOCK_METHOD((StatusOrGetHloModules(); } + StatusOr>> GetOutputMemoryKinds() + const override { + DCHECK(this); + return pjrt_loaded_executable_->GetOutputMemoryKinds(); + } + PjRtCompatibleClient* client() const override { DCHECK(this); return client_; diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index a71b9d441e66af..84a1ea7310c7f2 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -354,6 +354,11 @@ PyLoadedExecutable::HloModules() const { return ifrt_loaded_executable_->GetHloModules(); } +StatusOr>> +PyLoadedExecutable::GetOutputMemoryKinds() const { + return ifrt_loaded_executable_->GetOutputMemoryKinds(); +} + std::optional> PyLoadedExecutable::GetParameterShardings() const { return ifrt_loaded_executable_->GetParameterShardings(); diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h index 1221f1d48466e5..c16ff0d307c303 100644 --- a/tensorflow/compiler/xla/python/py_executable.h +++ b/tensorflow/compiler/xla/python/py_executable.h @@ -27,11 +27,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_executable.h" #include "tensorflow/compiler/xla/python/py_array.h" -#include "tensorflow/compiler/xla/python/py_buffer.h" #include "tensorflow/compiler/xla/python/py_client.h" #include "tensorflow/compiler/xla/python/traceback.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" namespace xla { @@ -169,6 +167,9 @@ class PyLoadedExecutable StatusOr>> HloModules() const; + StatusOr>> GetOutputMemoryKinds() + const; + std::optional> GetParameterShardings() const; std::optional> GetOutputShardings() const; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index f156dbb50448aa..ba86a32c5609f7 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -654,6 +654,8 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("arguments"), py::arg("with_tokens") = false) .def("hlo_modules", xla::ValueOrThrowWrapper(&PyLoadedExecutable::HloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PyLoadedExecutable::GetOutputMemoryKinds)) .def("get_output_shardings", &PyLoadedExecutable::GetOutputShardings) .def("get_parameter_shardings", &PyLoadedExecutable::GetParameterShardings) @@ -943,6 +945,8 @@ PYBIND11_MODULE(xla_extension, m) { py::class_>(m, "Executable") .def("hlo_modules", xla::ValueOrThrowWrapper(&PjRtExecutable::GetHloModules)) + .def("get_output_memory_kinds", + xla::ValueOrThrowWrapper(&PjRtExecutable::GetOutputMemoryKinds)) .def("get_output_shardings", &PjRtExecutable::GetOutputShardings) .def("get_parameter_shardings", &PjRtExecutable::GetParameterShardings) .def("get_compiled_memory_stats", diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 90e1414154d5c5..6bfb6f15e4cfce 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,7 +44,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 172 +_version = 173 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi index 37616f46d5bcab..64508460b9710d 100644 --- a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi @@ -553,6 +553,7 @@ class LoadedExecutable: self, arguments: Sequence[List[ArrayImpl]], with_tokens: bool = ...) -> ExecuteResults: ... def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def keep_alive(self) -> None: ... def compile_options(self) -> CompileOptions: ... @@ -562,6 +563,7 @@ class LoadedExecutable: class Executable: def hlo_modules(self) -> List[HloModule]: ... + def get_output_memory_kinds(self) -> List[List[str]]: ... def get_output_shardings(self) -> Optional[List[OpSharding]]: ... def get_parameter_shardings(self) -> Optional[List[OpSharding]]: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... 
From b863f9f52757e349cf57d726582e69b4fbebb2c3 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Tue, 25 Jul 2023 12:34:05 -0700 Subject: [PATCH 130/410] Remove the usage of obsolete execution_context PiperOrigin-RevId: 550964622 --- tensorflow/core/tfrt/mlrt/kernel/kernel.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc index b47be8d5aa6520..2149917cf95a65 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/kernel.cc @@ -233,7 +233,6 @@ void AsyncWhileOp::OnPredicateReady( std::move(final_promise).SetError(status); } }); - execution_context.Fail(status); return; } From 4b3688657a3aca43e57ef8e9d46e8ec01f0c63b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 12:37:31 -0700 Subject: [PATCH 131/410] Remove python3.8 from hermetic python PiperOrigin-RevId: 550965442 --- ci/official/requirements_updater/BUILD.bazel | 8 - ci/official/requirements_updater/WORKSPACE | 1 - ci/official/requirements_updater/updater.sh | 2 +- requirements_lock_3_8.txt | 443 ------------------ .../tools/toolchains/python/python_repo.bzl | 2 +- 5 files changed, 2 insertions(+), 454 deletions(-) delete mode 100644 requirements_lock_3_8.txt diff --git a/ci/official/requirements_updater/BUILD.bazel b/ci/official/requirements_updater/BUILD.bazel index 5377418368b7c8..4f2bb9b8d9edd4 100644 --- a/ci/official/requirements_updater/BUILD.bazel +++ b/ci/official/requirements_updater/BUILD.bazel @@ -13,18 +13,10 @@ # limitations under the License. # ============================================================================== -load("@python//3.8:defs.bzl", compile_pip_requirements_3_8 = "compile_pip_requirements") load("@python//3.9:defs.bzl", compile_pip_requirements_3_9 = "compile_pip_requirements") load("@python//3.10:defs.bzl", compile_pip_requirements_3_10 = "compile_pip_requirements") load("@python//3.11:defs.bzl", compile_pip_requirements_3_11 = "compile_pip_requirements") -compile_pip_requirements_3_8( - name = "requirements_3_8", - extra_args = ["--allow-unsafe"], - requirements_in = "requirements.in", - requirements_txt = "requirements_lock_3_8.txt", -) - compile_pip_requirements_3_9( name = "requirements_3_9", extra_args = ["--allow-unsafe"], diff --git a/ci/official/requirements_updater/WORKSPACE b/ci/official/requirements_updater/WORKSPACE index 1bd9e86aba9e7e..362946bd03a6d2 100644 --- a/ci/official/requirements_updater/WORKSPACE +++ b/ci/official/requirements_updater/WORKSPACE @@ -28,7 +28,6 @@ python_register_multi_toolchains( default_version = default_python_version, ignore_root_user_error = True, python_versions = [ - "3.8", "3.9", "3.10", "3.11", diff --git a/ci/official/requirements_updater/updater.sh b/ci/official/requirements_updater/updater.sh index 9dd5036fcf4046..9ee382b7612c23 100644 --- a/ci/official/requirements_updater/updater.sh +++ b/ci/official/requirements_updater/updater.sh @@ -21,7 +21,7 @@ mv BUILD.bazel BUILD -SUPPORTED_VERSIONS=("3_8" "3_9" "3_10" "3_11") +SUPPORTED_VERSIONS=("3_9" "3_10" "3_11") for VERSION in "${SUPPORTED_VERSIONS[@]}" do diff --git a/requirements_lock_3_8.txt b/requirements_lock_3_8.txt deleted file mode 100644 index 52dc137e39f429..00000000000000 --- a/requirements_lock_3_8.txt +++ /dev/null @@ -1,443 +0,0 @@ -absl-py==1.4.0 \ - --hash=sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47 \ - --hash=sha256:d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d - # via tb-nightly 
-cachetools==5.3.1 \ - --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \ - --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b - # via google-auth -certifi==2023.5.7 \ - --hash=sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7 \ - --hash=sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716 - # via requests -charset-normalizer==3.1.0 \ - --hash=sha256:04afa6387e2b282cf78ff3dbce20f0cc071c12dc8f685bd40960cc68644cfea6 \ - --hash=sha256:04eefcee095f58eaabe6dc3cc2262f3bcd776d2c67005880894f447b3f2cb9c1 \ - --hash=sha256:0be65ccf618c1e7ac9b849c315cc2e8a8751d9cfdaa43027d4f6624bd587ab7e \ - --hash=sha256:0c95f12b74681e9ae127728f7e5409cbbef9cd914d5896ef238cc779b8152373 \ - --hash=sha256:0ca564606d2caafb0abe6d1b5311c2649e8071eb241b2d64e75a0d0065107e62 \ - --hash=sha256:10c93628d7497c81686e8e5e557aafa78f230cd9e77dd0c40032ef90c18f2230 \ - --hash=sha256:11d117e6c63e8f495412d37e7dc2e2fff09c34b2d09dbe2bee3c6229577818be \ - --hash=sha256:11d3bcb7be35e7b1bba2c23beedac81ee893ac9871d0ba79effc7fc01167db6c \ - --hash=sha256:12a2b561af122e3d94cdb97fe6fb2bb2b82cef0cdca131646fdb940a1eda04f0 \ - --hash=sha256:12d1a39aa6b8c6f6248bb54550efcc1c38ce0d8096a146638fd4738e42284448 \ - --hash=sha256:1435ae15108b1cb6fffbcea2af3d468683b7afed0169ad718451f8db5d1aff6f \ - --hash=sha256:1c60b9c202d00052183c9be85e5eaf18a4ada0a47d188a83c8f5c5b23252f649 \ - --hash=sha256:1e8fcdd8f672a1c4fc8d0bd3a2b576b152d2a349782d1eb0f6b8e52e9954731d \ - --hash=sha256:20064ead0717cf9a73a6d1e779b23d149b53daf971169289ed2ed43a71e8d3b0 \ - --hash=sha256:21fa558996782fc226b529fdd2ed7866c2c6ec91cee82735c98a197fae39f706 \ - --hash=sha256:22908891a380d50738e1f978667536f6c6b526a2064156203d418f4856d6e86a \ - --hash=sha256:3160a0fd9754aab7d47f95a6b63ab355388d890163eb03b2d2b87ab0a30cfa59 \ - --hash=sha256:322102cdf1ab682ecc7d9b1c5eed4ec59657a65e1c146a0da342b78f4112db23 \ - --hash=sha256:34e0a2f9c370eb95597aae63bf85eb5e96826d81e3dcf88b8886012906f509b5 \ - --hash=sha256:3573d376454d956553c356df45bb824262c397c6e26ce43e8203c4c540ee0acb \ - --hash=sha256:3747443b6a904001473370d7810aa19c3a180ccd52a7157aacc264a5ac79265e \ - --hash=sha256:38e812a197bf8e71a59fe55b757a84c1f946d0ac114acafaafaf21667a7e169e \ - --hash=sha256:3a06f32c9634a8705f4ca9946d667609f52cf130d5548881401f1eb2c39b1e2c \ - --hash=sha256:3a5fc78f9e3f501a1614a98f7c54d3969f3ad9bba8ba3d9b438c3bc5d047dd28 \ - --hash=sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d \ - --hash=sha256:3dc5b6a8ecfdc5748a7e429782598e4f17ef378e3e272eeb1340ea57c9109f41 \ - --hash=sha256:4155b51ae05ed47199dc5b2a4e62abccb274cee6b01da5b895099b61b1982974 \ - --hash=sha256:49919f8400b5e49e961f320c735388ee686a62327e773fa5b3ce6721f7e785ce \ - --hash=sha256:53d0a3fa5f8af98a1e261de6a3943ca631c526635eb5817a87a59d9a57ebf48f \ - --hash=sha256:5f008525e02908b20e04707a4f704cd286d94718f48bb33edddc7d7b584dddc1 \ - --hash=sha256:628c985afb2c7d27a4800bfb609e03985aaecb42f955049957814e0491d4006d \ - --hash=sha256:65ed923f84a6844de5fd29726b888e58c62820e0769b76565480e1fdc3d062f8 \ - --hash=sha256:6734e606355834f13445b6adc38b53c0fd45f1a56a9ba06c2058f86893ae8017 \ - --hash=sha256:6baf0baf0d5d265fa7944feb9f7451cc316bfe30e8df1a61b1bb08577c554f31 \ - --hash=sha256:6f4f4668e1831850ebcc2fd0b1cd11721947b6dc7c00bf1c6bd3c929ae14f2c7 \ - --hash=sha256:6f5c2e7bc8a4bf7c426599765b1bd33217ec84023033672c1e9a8b35eaeaaaf8 \ - --hash=sha256:6f6c7a8a57e9405cad7485f4c9d3172ae486cfef1344b5ddd8e5239582d7355e \ - 
--hash=sha256:7381c66e0561c5757ffe616af869b916c8b4e42b367ab29fedc98481d1e74e14 \ - --hash=sha256:73dc03a6a7e30b7edc5b01b601e53e7fc924b04e1835e8e407c12c037e81adbd \ - --hash=sha256:74db0052d985cf37fa111828d0dd230776ac99c740e1a758ad99094be4f1803d \ - --hash=sha256:75f2568b4189dda1c567339b48cba4ac7384accb9c2a7ed655cd86b04055c795 \ - --hash=sha256:78cacd03e79d009d95635e7d6ff12c21eb89b894c354bd2b2ed0b4763373693b \ - --hash=sha256:80d1543d58bd3d6c271b66abf454d437a438dff01c3e62fdbcd68f2a11310d4b \ - --hash=sha256:830d2948a5ec37c386d3170c483063798d7879037492540f10a475e3fd6f244b \ - --hash=sha256:891cf9b48776b5c61c700b55a598621fdb7b1e301a550365571e9624f270c203 \ - --hash=sha256:8f25e17ab3039b05f762b0a55ae0b3632b2e073d9c8fc88e89aca31a6198e88f \ - --hash=sha256:9a3267620866c9d17b959a84dd0bd2d45719b817245e49371ead79ed4f710d19 \ - --hash=sha256:a04f86f41a8916fe45ac5024ec477f41f886b3c435da2d4e3d2709b22ab02af1 \ - --hash=sha256:aaf53a6cebad0eae578f062c7d462155eada9c172bd8c4d250b8c1d8eb7f916a \ - --hash=sha256:abc1185d79f47c0a7aaf7e2412a0eb2c03b724581139193d2d82b3ad8cbb00ac \ - --hash=sha256:ac0aa6cd53ab9a31d397f8303f92c42f534693528fafbdb997c82bae6e477ad9 \ - --hash=sha256:ac3775e3311661d4adace3697a52ac0bab17edd166087d493b52d4f4f553f9f0 \ - --hash=sha256:b06f0d3bf045158d2fb8837c5785fe9ff9b8c93358be64461a1089f5da983137 \ - --hash=sha256:b116502087ce8a6b7a5f1814568ccbd0e9f6cfd99948aa59b0e241dc57cf739f \ - --hash=sha256:b82fab78e0b1329e183a65260581de4375f619167478dddab510c6c6fb04d9b6 \ - --hash=sha256:bd7163182133c0c7701b25e604cf1611c0d87712e56e88e7ee5d72deab3e76b5 \ - --hash=sha256:c36bcbc0d5174a80d6cccf43a0ecaca44e81d25be4b7f90f0ed7bcfbb5a00909 \ - --hash=sha256:c3af8e0f07399d3176b179f2e2634c3ce9c1301379a6b8c9c9aeecd481da494f \ - --hash=sha256:c84132a54c750fda57729d1e2599bb598f5fa0344085dbde5003ba429a4798c0 \ - --hash=sha256:cb7b2ab0188829593b9de646545175547a70d9a6e2b63bf2cd87a0a391599324 \ - --hash=sha256:cca4def576f47a09a943666b8f829606bcb17e2bc2d5911a46c8f8da45f56755 \ - --hash=sha256:cf6511efa4801b9b38dc5546d7547d5b5c6ef4b081c60b23e4d941d0eba9cbeb \ - --hash=sha256:d16fd5252f883eb074ca55cb622bc0bee49b979ae4e8639fff6ca3ff44f9f854 \ - --hash=sha256:d2686f91611f9e17f4548dbf050e75b079bbc2a82be565832bc8ea9047b61c8c \ - --hash=sha256:d7fc3fca01da18fbabe4625d64bb612b533533ed10045a2ac3dd194bfa656b60 \ - --hash=sha256:dd5653e67b149503c68c4018bf07e42eeed6b4e956b24c00ccdf93ac79cdff84 \ - --hash=sha256:de5695a6f1d8340b12a5d6d4484290ee74d61e467c39ff03b39e30df62cf83a0 \ - --hash=sha256:e0ac8959c929593fee38da1c2b64ee9778733cdf03c482c9ff1d508b6b593b2b \ - --hash=sha256:e1b25e3ad6c909f398df8921780d6a3d120d8c09466720226fc621605b6f92b1 \ - --hash=sha256:e633940f28c1e913615fd624fcdd72fdba807bf53ea6925d6a588e84e1151531 \ - --hash=sha256:e89df2958e5159b811af9ff0f92614dabf4ff617c03a4c1c6ff53bf1c399e0e1 \ - --hash=sha256:ea9f9c6034ea2d93d9147818f17c2a0860d41b71c38b9ce4d55f21b6f9165a11 \ - --hash=sha256:f645caaf0008bacf349875a974220f1f1da349c5dbe7c4ec93048cdc785a3326 \ - --hash=sha256:f8303414c7b03f794347ad062c0516cee0e15f7a612abd0ce1e25caf6ceb47df \ - --hash=sha256:fca62a8301b605b954ad2e9c3666f9d97f63872aa4efcae5492baca2056b74ab - # via requests -google-auth==2.19.1 \ - --hash=sha256:a9cfa88b3e16196845e64a3658eb953992129d13ac7337b064c6546f77c17183 \ - --hash=sha256:ea165e014c7cbd496558796b627c271aa8c18b4cba79dc1cc962b24c5efdfb85 - # via - # google-auth-oauthlib - # tb-nightly -google-auth-oauthlib==1.0.0 \ - --hash=sha256:95880ca704928c300f48194d1770cf5b1462835b6e49db61445a520f793fd5fb \ - 
--hash=sha256:e375064964820b47221a7e1b7ee1fd77051b6323c3f9e3e19785f78ab67ecfc5 - # via tb-nightly -grpcio==1.54.2 \ - --hash=sha256:0212e2f7fdf7592e4b9d365087da30cb4d71e16a6f213120c89b4f8fb35a3ab3 \ - --hash=sha256:09d4bfd84686cd36fd11fd45a0732c7628308d094b14d28ea74a81db0bce2ed3 \ - --hash=sha256:1e623e0cf99a0ac114f091b3083a1848dbc64b0b99e181473b5a4a68d4f6f821 \ - --hash=sha256:2288d76e4d4aa7ef3fe7a73c1c470b66ea68e7969930e746a8cd8eca6ef2a2ea \ - --hash=sha256:2296356b5c9605b73ed6a52660b538787094dae13786ba53080595d52df13a98 \ - --hash=sha256:2a1e601ee31ef30a9e2c601d0867e236ac54c922d32ed9f727b70dd5d82600d5 \ - --hash=sha256:2be88c081e33f20630ac3343d8ad9f1125f32987968e9c8c75c051c9800896e8 \ - --hash=sha256:33d40954199bddbb6a78f8f6f2b2082660f381cd2583ec860a6c2fa7c8400c08 \ - --hash=sha256:40e1cbf69d6741b40f750f3cccc64326f927ac6145a9914d33879e586002350c \ - --hash=sha256:46a057329938b08e5f0e12ea3d7aed3ecb20a0c34c4a324ef34e00cecdb88a12 \ - --hash=sha256:4864f99aac207e3e45c5e26c6cbb0ad82917869abc2f156283be86c05286485c \ - --hash=sha256:4c44e1a765b31e175c391f22e8fc73b2a2ece0e5e6ff042743d8109b5d2eff9f \ - --hash=sha256:4cb283f630624ebb16c834e5ac3d7880831b07cbe76cb08ab7a271eeaeb8943e \ - --hash=sha256:5008964885e8d23313c8e5ea0d44433be9bfd7e24482574e8cc43c02c02fc796 \ - --hash=sha256:50a9f075eeda5097aa9a182bb3877fe1272875e45370368ac0ee16ab9e22d019 \ - --hash=sha256:51630c92591d6d3fe488a7c706bd30a61594d144bac7dee20c8e1ce78294f474 \ - --hash=sha256:5cc928cfe6c360c1df636cf7991ab96f059666ac7b40b75a769410cc6217df9c \ - --hash=sha256:61f7203e2767800edee7a1e1040aaaf124a35ce0c7fe0883965c6b762defe598 \ - --hash=sha256:66233ccd2a9371158d96e05d082043d47dadb18cbb294dc5accfdafc2e6b02a7 \ - --hash=sha256:70fcac7b94f4c904152809a050164650ac81c08e62c27aa9f156ac518029ebbe \ - --hash=sha256:714242ad0afa63a2e6dabd522ae22e1d76e07060b5af2ddda5474ba4f14c2c94 \ - --hash=sha256:782f4f8662a2157c4190d0f99eaaebc602899e84fb1e562a944e5025929e351c \ - --hash=sha256:7fc2b4edb938c8faa4b3c3ea90ca0dd89b7565a049e8e4e11b77e60e4ed2cc05 \ - --hash=sha256:881d058c5ccbea7cc2c92085a11947b572498a27ef37d3eef4887f499054dca8 \ - --hash=sha256:89dde0ac72a858a44a2feb8e43dc68c0c66f7857a23f806e81e1b7cc7044c9cf \ - --hash=sha256:8cdbcbd687e576d48f7886157c95052825ca9948c0ed2afdc0134305067be88b \ - --hash=sha256:8d6192c37a30a115f4663592861f50e130caed33efc4eec24d92ec881c92d771 \ - --hash=sha256:96a41817d2c763b1d0b32675abeb9179aa2371c72aefdf74b2d2b99a1b92417b \ - --hash=sha256:9bdbb7624d65dc0ed2ed8e954e79ab1724526f09b1efa88dcd9a1815bf28be5f \ - --hash=sha256:9bf88004fe086c786dc56ef8dd6cb49c026833fdd6f42cb853008bce3f907148 \ - --hash=sha256:a08920fa1a97d4b8ee5db2f31195de4a9def1a91bc003544eb3c9e6b8977960a \ - --hash=sha256:a2f5a1f1080ccdc7cbaf1171b2cf384d852496fe81ddedeb882d42b85727f610 \ - --hash=sha256:b04202453941a63b36876a7172b45366dc0cde10d5fd7855c0f4a4e673c0357a \ - --hash=sha256:b38b3de8cff5bc70f8f9c615f51b48eff7313fc9aca354f09f81b73036e7ddfa \ - --hash=sha256:b52d00d1793d290c81ad6a27058f5224a7d5f527867e5b580742e1bd211afeee \ - --hash=sha256:b74ae837368cfffeb3f6b498688a123e6b960951be4dec0e869de77e7fa0439e \ - --hash=sha256:be48496b0e00460717225e7680de57c38be1d8629dc09dadcd1b3389d70d942b \ - --hash=sha256:c0e3155fc5335ec7b3b70f15230234e529ca3607b20a562b6c75fb1b1218874c \ - --hash=sha256:c2392f5b5d84b71d853918687d806c1aa4308109e5ca158a16e16a6be71041eb \ - --hash=sha256:c72956972e4b508dd39fdc7646637a791a9665b478e768ffa5f4fe42123d5de1 \ - --hash=sha256:dc80c9c6b608bf98066a038e0172013a49cfa9a08d53335aefefda2c64fc68f4 \ - 
--hash=sha256:e416c8baf925b5a1aff31f7f5aecc0060b25d50cce3a5a7255dc5cf2f1d4e5eb \ - --hash=sha256:f8da84bbc61a4e92af54dc96344f328e5822d574f767e9b08e1602bb5ddc254a \ - --hash=sha256:f900ed4ad7a0f1f05d35f955e0943944d5a75f607a836958c6b8ab2a81730ef2 \ - --hash=sha256:fd6c6c29717724acf9fc1847c4515d57e4dc12762452457b9cb37461f30a81bb - # via - # -r ./requirements.in - # tb-nightly -h5py==3.8.0 \ - --hash=sha256:03890b1c123d024fb0239a3279737d5432498c1901c354f8b10d8221d1d16235 \ - --hash=sha256:0fef76e10b9216657fa37e7edff6d8be0709b25bd5066474c229b56cf0098df9 \ - --hash=sha256:26ffc344ec9984d2cd3ca0265007299a8bac8d85c1ad48f4639d8d3aed2af171 \ - --hash=sha256:290e00fa2de74a10688d1bac98d5a9cdd43f14f58e562c580b5b3dfbd358ecae \ - --hash=sha256:33b15aae79e9147aebe1d0e54099cbcde8d65e3e227cd5b59e49b1272aa0e09d \ - --hash=sha256:36761693efbe53df179627a775476dcbc37727d6e920958277a7efbc18f1fb73 \ - --hash=sha256:377865821fe80ad984d003723d6f8890bd54ceeb5981b43c0313b9df95411b30 \ - --hash=sha256:49bc857635f935fa30e92e61ac1e87496df8f260a6945a3235e43a9890426866 \ - --hash=sha256:4a506fc223def428f4329e7e1f9fe1c8c593eab226e7c0942c8d75308ad49950 \ - --hash=sha256:533d7dad466ddb7e3b30af274b630eb7c1a6e4ddf01d1c373a0334dc2152110a \ - --hash=sha256:5fd2252d1fc364ba0e93dd0b7089f4906b66805cb4e6aca7fa8874ac08649647 \ - --hash=sha256:6fead82f0c4000cf38d53f9c030780d81bfa0220218aee13b90b7701c937d95f \ - --hash=sha256:7f3350fc0a8407d668b13247861c2acd23f7f5fe7d060a3ad9b0820f5fcbcae0 \ - --hash=sha256:8f55d9c6c84d7d09c79fb85979e97b81ec6071cc776a97eb6b96f8f6ec767323 \ - --hash=sha256:98a240cd4c1bfd568aaa52ec42d263131a2582dab82d74d3d42a0d954cac12be \ - --hash=sha256:9f6f6ffadd6bfa9b2c5b334805eb4b19ca0a5620433659d8f7fb86692c40a359 \ - --hash=sha256:b685453e538b2b5934c58a644ac3f3b3d0cec1a01b6fb26d57388e9f9b674ad0 \ - --hash=sha256:b7865de06779b14d98068da387333ad9bf2756b5b579cc887fac169bc08f87c3 \ - --hash=sha256:bacaa1c16810dd2b3e4417f8e730971b7c4d53d234de61fe4a918db78e80e1e4 \ - --hash=sha256:bae730580ae928de409d63cbe4fdca4c82c3ad2bed30511d19d34e995d63c77e \ - --hash=sha256:c3389b63222b1c7a158bb7fe69d11ca00066740ec5574596d47a2fe5317f563a \ - --hash=sha256:c873ba9fd4fa875ad62ce0e4891725e257a8fe7f5abdbc17e51a5d54819be55c \ - --hash=sha256:db03e3f2c716205fbdabb34d0848459840585225eb97b4f08998c743821ca323 \ - --hash=sha256:f47f757d1b76f0ecb8aa0508ec8d1b390df67a8b67ee2515dc1b046f3a1596ea \ - --hash=sha256:f891b17e3a3e974e93f9e34e7cca9f530806543571ce078998676a555837d91d - # via -r ./requirements.in -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 - # via requests -importlib-metadata==6.6.0 \ - --hash=sha256:43dd286a2cd8995d5eaef7fee2066340423b818ed3fd70adf0bad5f1fac53fed \ - --hash=sha256:92501cdf9cc66ebd3e612f1b4f0c0765dfa42f0fa38ffb319b6bd84dd675d705 - # via markdown -jax==0.4.7 \ - --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 - # via -r ./requirements.in -keras-nightly==2.14.0.dev2023061207 \ - --hash=sha256:210671f010a0b21a5507be86b8e9e909f81b9f321cd3c51e1efdfdd41061919f \ - --hash=sha256:ad869b2bce863e111e4a57c7f5785f56097d93f683b5315df7f59917be1fa279 - # via -r ./requirements.in -lit==16.0.5.post0 \ - --hash=sha256:71745d9e58dad3717735d27e2a9cca0e9ca6861d067da73c307e02fd38c98479 - # via -r ./requirements.in -markdown==3.4.3 \ - --hash=sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2 \ - 
--hash=sha256:8bf101198e004dc93e84a12a7395e31aac6a9c9942848ae1d99b9d72cf9b3520 - # via tb-nightly -markupsafe==2.1.3 \ - --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ - --hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ - --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ - --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ - --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ - --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ - --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ - --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ - --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ - --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ - --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ - --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ - --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ - --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ - --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ - --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ - --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ - --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ - --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ - --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ - --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ - --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ - --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ - --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ - --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ - --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ - --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ - --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ - --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ - --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ - --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ - --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ - --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ - --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ - --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ - --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ - --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ - --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ - --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ - --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ - --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ - 
--hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ - --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ - --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ - --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ - --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ - --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ - --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ - --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ - --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 - # via werkzeug -ml-dtypes==0.2.0 \ - --hash=sha256:022d5a4ee6be14569c2a9d1549e16f1ec87ca949681d0dca59995445d5fcdd5b \ - --hash=sha256:1749b60348da71fd3c2ab303fdbc1965958dc50775ead41f5669c932a341cafd \ - --hash=sha256:32107e7fa9f62db9a5281de923861325211dfff87bd23faefb27b303314635ab \ - --hash=sha256:35b984cddbe8173b545a0e3334fe56ea1a5c3eb67c507f60d0cfde1d3fa8f8c2 \ - --hash=sha256:36d28b8861a8931695e5a31176cad5ae85f6504906650dea5598fbec06c94606 \ - --hash=sha256:50845af3e9a601810751b55091dee6c2562403fa1cb4e0123675cf3a4fc2c17a \ - --hash=sha256:6488eb642acaaf08d8020f6de0a38acee7ac324c1e6e92ee0c0fea42422cb797 \ - --hash=sha256:75015818a7fccf99a5e8ed18720cb430f3e71a8838388840f4cdf225c036c983 \ - --hash=sha256:80d304c836d73f10605c58ccf7789c171cc229bfb678748adfb7cea2510dfd0e \ - --hash=sha256:832a019a1b6db5c4422032ca9940a990fa104eee420f643713241b3a518977fa \ - --hash=sha256:8faaf0897942c8253dd126662776ba45f0a5861968cf0f06d6d465f8a7bc298a \ - --hash=sha256:bc29a0524ef5e23a7fbb8d881bdecabeb3fc1d19d9db61785d077a86cb94fab2 \ - --hash=sha256:df6a76e1c8adf484feb138ed323f9f40a7b6c21788f120f7c78bec20ac37ee81 \ - --hash=sha256:e70047ec2c83eaee01afdfdabee2c5b0c133804d90d0f7db4dd903360fcc537c \ - --hash=sha256:e85ba8e24cf48d456e564688e981cf379d4c8e644db0a2f719b78de281bac2ca \ - --hash=sha256:f00c71c8c63e03aff313bc6a7aeaac9a4f1483a921a6ffefa6d4404efd1af3d0 \ - --hash=sha256:f08c391c2794f2aad358e6f4c70785a9a7b1df980ef4c232b3ccd4f6fe39f719 - # via jax -numpy==1.23.5 \ - --hash=sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d \ - --hash=sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07 \ - --hash=sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df \ - --hash=sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9 \ - --hash=sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d \ - --hash=sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a \ - --hash=sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719 \ - --hash=sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2 \ - --hash=sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280 \ - --hash=sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa \ - --hash=sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387 \ - --hash=sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1 \ - --hash=sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43 \ - --hash=sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f \ - --hash=sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398 \ - --hash=sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63 \ - 
--hash=sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de \ - --hash=sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8 \ - --hash=sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481 \ - --hash=sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0 \ - --hash=sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d \ - --hash=sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e \ - --hash=sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96 \ - --hash=sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb \ - --hash=sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6 \ - --hash=sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d \ - --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ - --hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 - # via - # -r ./requirements.in - # h5py - # jax - # ml-dtypes - # opt-einsum - # scipy - # tb-nightly -oauthlib==3.2.2 \ - --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ - --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 - # via requests-oauthlib -opt-einsum==3.3.0 \ - --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ - --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 - # via jax -packaging==23.1 \ - --hash=sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61 \ - --hash=sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f - # via -r ./requirements.in -portpicker==1.5.2 \ - --hash=sha256:01113f51c3cc63290a44dd7ae6e3eb9f8fe1b8a1f9d7988a897944230c39cd52 \ - --hash=sha256:c55683ad725f5c00a41bc7db0225223e8be024b1fa564d039ed3390e4fd48fb3 - # via -r ./requirements.in -protobuf==4.23.2 \ - --hash=sha256:09310bce43353b46d73ba7e3bca78273b9bc50349509b9698e64d288c6372c2a \ - --hash=sha256:20874e7ca4436f683b64ebdbee2129a5a2c301579a67d1a7dda2cdf62fb7f5f7 \ - --hash=sha256:25e3370eda26469b58b602e29dff069cfaae8eaa0ef4550039cc5ef8dc004511 \ - --hash=sha256:281342ea5eb631c86697e1e048cb7e73b8a4e85f3299a128c116f05f5c668f8f \ - --hash=sha256:384dd44cb4c43f2ccddd3645389a23ae61aeb8cfa15ca3a0f60e7c3ea09b28b3 \ - --hash=sha256:54a533b971288af3b9926e53850c7eb186886c0c84e61daa8444385a4720297f \ - --hash=sha256:6c081863c379bb1741be8f8193e893511312b1d7329b4a75445d1ea9955be69e \ - --hash=sha256:86df87016d290143c7ce3be3ad52d055714ebaebb57cc659c387e76cfacd81aa \ - --hash=sha256:8da6070310d634c99c0db7df48f10da495cc283fd9e9234877f0cd182d43ab7f \ - --hash=sha256:b2cfab63a230b39ae603834718db74ac11e52bccaaf19bf20f5cce1a84cf76df \ - --hash=sha256:c52cfcbfba8eb791255edd675c1fe6056f723bf832fa67f0442218f8817c076e \ - --hash=sha256:ce744938406de1e64b91410f473736e815f28c3b71201302612a68bf01517fea \ - --hash=sha256:efabbbbac1ab519a514579ba9ec52f006c28ae19d97915951f69fa70da2c9e91 - # via tb-nightly -psutil==5.9.5 \ - --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \ - --hash=sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217 \ - --hash=sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4 \ - --hash=sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c \ - --hash=sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f \ - 
--hash=sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da \ - --hash=sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4 \ - --hash=sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42 \ - --hash=sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5 \ - --hash=sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4 \ - --hash=sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9 \ - --hash=sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f \ - --hash=sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30 \ - --hash=sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48 - # via portpicker -pyasn1==0.5.0 \ - --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ - --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 \ - --hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ - --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d - # via google-auth -requests==2.31.0 \ - --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ - --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 - # via - # -r ./requirements.in - # requests-oauthlib - # tb-nightly -requests-oauthlib==1.3.1 \ - --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ - --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a - # via google-auth-oauthlib -rsa==4.9 \ - --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ - --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 - # via google-auth -scipy==1.10.1 \ - --hash=sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415 \ - --hash=sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f \ - --hash=sha256:0f1564ea217e82c1bbe75ddf7285ba0709ecd503f048cb1236ae9995f64217bd \ - --hash=sha256:1553b5dcddd64ba9a0d95355e63fe6c3fc303a8fd77c7bc91e77d61363f7433f \ - --hash=sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d \ - --hash=sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601 \ - --hash=sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5 \ - --hash=sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88 \ - --hash=sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f \ - --hash=sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e \ - --hash=sha256:4c0ff64b06b10e35215abce517252b375e580a6125fd5fdf6421b98efbefb2d2 \ - --hash=sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353 \ - --hash=sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35 \ - --hash=sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6 \ - --hash=sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea \ - --hash=sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35 \ - --hash=sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1 \ - --hash=sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9 \ - --hash=sha256:d925fa1c81b772882aa55bcc10bf88324dadb66ff85d548c71515f6689c6dac5 \ - --hash=sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019 \ - 
--hash=sha256:fae8a7b898c42dffe3f7361c40d5952b6bf32d10c4569098d276b4c547905ee1 - # via - # -r ./requirements.in - # jax -setuptools==67.6.1 \ - --hash=sha256:257de92a9d50a60b8e22abfcbb771571fde0dbf3ec234463212027a4eeecbe9a \ - --hash=sha256:e728ca814a823bf7bf60162daf9db95b93d532948c4c0bea762ce62f60189078 - # via - # -r ./requirements.in - # tb-nightly -six==1.16.0 \ - --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ - --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 - # via google-auth -tb-nightly==2.14.0a20230612 \ - --hash=sha256:1ad7d57386f6103df69d8f3083552d185d33c5a07fb18e691b60236d9b4dc679 - # via -r ./requirements.in -tensorboard-data-server==0.7.0 \ - --hash=sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f \ - --hash=sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb \ - --hash=sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454 - # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023060108 \ - --hash=sha256:09f63c090f29b74ebb36076c0ef1105c8b0358c6920847f312926012926ff7ce - # via -r ./requirements.in -urllib3==1.26.16 \ - --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ - --hash=sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14 - # via - # google-auth - # requests -werkzeug==2.3.6 \ - --hash=sha256:935539fa1413afbb9195b24880778422ed620c0fc09670945185cce4d91a8890 \ - --hash=sha256:98c774df2f91b05550078891dee5f0eb0cb797a522c757a2452b9cee5b202330 - # via tb-nightly -wheel==0.38.4 \ - --hash=sha256:965f5259b566725405b05e7cf774052044b1ed30119b5d586b2703aafe8719ac \ - --hash=sha256:b60533f3f5d530e971d6737ca6d58681ee434818fab630c83a734bb10c083ce8 - # via - # -r ./requirements.in - # tb-nightly -zipp==3.15.0 \ - --hash=sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b \ - --hash=sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556 - # via importlib-metadata diff --git a/tensorflow/tools/toolchains/python/python_repo.bzl b/tensorflow/tools/toolchains/python/python_repo.bzl index 7dfb571b68cc47..61a45964649db7 100644 --- a/tensorflow/tools/toolchains/python/python_repo.bzl +++ b/tensorflow/tools/toolchains/python/python_repo.bzl @@ -4,7 +4,7 @@ Can be set via build parameter "--repo_env=TF_PYTHON_VERSION=3.10" Defaults to 3.10. """ -VERSIONS = ["3.8", "3.9", "3.10", "3.11"] +VERSIONS = ["3.9", "3.10", "3.11"] DEFAULT_VERSION = "3.10" WARNING = """ TF_PYTHON_VERSION variable was not set correctly, using default version. {} Python From ccda01716e37ce652c12f582bd10475e2dd49bb6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 12:38:25 -0700 Subject: [PATCH 132/410] Add missing comma in string lists leading to unwanted implicit concat. 
PiperOrigin-RevId: 550965686 --- .../profiler/convert/hlo_proto_to_memory_visualization_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc index 92bf64d0f43583..c562642ecad327 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc @@ -832,7 +832,7 @@ void ConvertAllocationTimeline(const HloProtoBufferWrapper& wrapper, "orange", "orangered", "orchid", - "palegoldenrod" + "palegoldenrod", "palegreen", "paleturquoise", "palevioletred", From 863451d7d0875aac843e65f485cdf7dcbbd72c0a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 12:40:10 -0700 Subject: [PATCH 133/410] Make new shape as close to the original shape as possible. PiperOrigin-RevId: 550966133 --- tensorflow/compiler/xla/shape_util.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7af02a290f6d72..d30e1b779b91f0 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -428,6 +428,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( for (int i = 0; i < shape.dimensions_size(); ++i) { new_shape.set_dynamic_dimension(i, shape.is_dynamic_dimension(i)); } + new_shape.mutable_layout()->set_memory_space(shape.layout().memory_space()); return new_shape; } From bd321826d84ddd4508a97ef228ffafd745f56a82 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Tue, 25 Jul 2023 13:08:26 -0700 Subject: [PATCH 134/410] #tf-data-service Fix bug when falling back to gRPC in the critical path. This fallback loop is supposed to terminate when the worker client's data transfer protocol is gRPC or local, but it mistakenly checks the user-specified data transfer protocol (which doesn't change after a given worker falls back). This results in an infinite loop in workflows with an expected nonpreemptable error, as in `//third_party/tensorflow/python/data/experimental/kernel_tests/service:coordinated_read_test` and `//third_party/tensorflow/python/data/experimental/kernel_tests/service:cross_trainer_cache_test`. PiperOrigin-RevId: 550974057 --- .../service/client/data_service_client.cc | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index f98bd77f5fa80e..5cbb955f8bb075 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -815,24 +815,20 @@ Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, break; } if (!IsPreemptedError(s)) { - std::string data_transfer_protocol = - !params_.data_transfer_protocol.empty() - ? params_.data_transfer_protocol - : DefaultDataTransferProtocol(); - if (data_transfer_protocol == kGrpcTransferProtocol || - data_transfer_protocol == kLocalTransferProtocol) { + if (task->worker->GetDataTransferProtocol() == kGrpcTransferProtocol || + task->worker->GetDataTransferProtocol() == kLocalTransferProtocol) { return s; } + LOG(ERROR) << "failed to use alternative data transfer protocol '" + << task->worker->GetDataTransferProtocol() + << "'; falling back to grpc. 
Original error: " << s; + metrics::RecordTFDataServiceDataTransferProtocolError( + task->worker->GetDataTransferProtocol(), + static_cast(s.raw_code()), std::string(s.message())); mutex_lock l(mu_); TF_ASSIGN_OR_RETURN(std::unique_ptr worker, CreateGrpcWorkerClient(task->info)); task->worker = std::move(worker); - LOG(ERROR) << "failed to use alternative data transfer protocol '" - << data_transfer_protocol << "'; falling back to grpc. " - << "Original error: " << s; - metrics::RecordTFDataServiceDataTransferProtocolError( - DefaultDataTransferProtocol(), static_cast(s.raw_code()), - std::string(s.message())); continue; } { From 2a8de4e558cdfb11a58a36f7b107f855b24e0f42 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Tue, 25 Jul 2023 13:19:29 -0700 Subject: [PATCH 135/410] Convert kTfNextPluggableDeviceUseCApi to a flag This makes the flag runtime configurable. It is still default to the build time value, to make it backward compatible. After we fully migrate to the runtime configuration, we can remove the build time flag entirely. PiperOrigin-RevId: 550977259 --- .../next_pluggable_device/BUILD | 9 ------- .../next_pluggable_device_c_api_flag.cc | 24 ----------------- .../next_pluggable_device_c_api_flag.h | 26 ------------------- ...plugin_coordination_service_agent_helper.h | 5 ++-- .../plugin_op_kernel_helper.h | 7 ++--- 5 files changed, 7 insertions(+), 64 deletions(-) delete mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.cc delete mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index 1d283021737aaf..c89135e0bfeef9 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -66,13 +66,6 @@ cc_library( ], ) -cc_library( - name = "next_pluggable_device_c_api_flag", - srcs = ["next_pluggable_device_c_api_flag.cc"], - hdrs = ["next_pluggable_device_c_api_flag.h"], - visibility = ["//visibility:public"], -) - cc_library( name = "next_pluggable_device_factory", srcs = [ @@ -223,7 +216,6 @@ cc_library( ":c_plugin_op_kernel", ":direct_plugin_op_kernel", ":loose_headers", - ":next_pluggable_device_c_api_flag", ":plugin_op_kernel", "//tensorflow/c:kernels_hdrs", "//tensorflow/c:tf_status_helper", @@ -290,7 +282,6 @@ cc_library( deps = [ ":c_plugin_coordination_service_agent", ":direct_plugin_coordination_service_agent", - ":next_pluggable_device_c_api_flag", ":plugin_coordination_service_agent", "//tensorflow/c:kernels_hdrs", "//tensorflow/c:tf_status_helper", diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.cc deleted file mode 100644 index 79c948838f4f39..00000000000000 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.cc +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h" - -namespace tensorflow { -namespace npd { - -// Define as a weak symbol so that it can be overridden when necessary. -extern const bool kTfNextPluggableDeviceUseCApi __attribute__((weak)) = false; - -} // namespace npd -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h deleted file mode 100644 index 2628a6e1648bdc..00000000000000 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_C_API_FLAG_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_C_API_FLAG_H_ - -namespace tensorflow { -namespace npd { - -extern const bool kTfNextPluggableDeviceUseCApi; - -} // namespace npd -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_C_API_FLAG_H_ diff --git a/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h b/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h index 3fd2bacba29d78..fda1e00d063c8f 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h @@ -20,14 +20,15 @@ limitations under the License. 
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h" #include "tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h" -#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h" #include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h" +ABSL_DECLARE_FLAG(bool, next_pluggable_device_use_c_api); + namespace tensorflow { inline PluginCoordinationServiceAgent* CreatePluginCoordinationServiceAgent( void* agent) { - if (!tensorflow::npd::kTfNextPluggableDeviceUseCApi) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { return new DirectPluginCoordinationServiceAgent(agent); } else { return new CPluginCoordinationServiceAgent(agent); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h b/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h index 3e087b9d41154c..46017c01be4222 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h @@ -20,13 +20,14 @@ limitations under the License. #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h" #include "tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h" -#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_c_api_flag.h" #include "tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h" +ABSL_DECLARE_FLAG(bool, next_pluggable_device_use_c_api); + namespace tensorflow { inline PluginOpKernelConstruction* CreatePluginOpKernelConstruction(void* ctx) { - if (!npd::kTfNextPluggableDeviceUseCApi) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { return new DirectPluginOpKernelConstruction(ctx); } else { return new CPluginOpKernelConstruction(ctx); @@ -39,7 +40,7 @@ inline void DeletePluginOpKernelConstruction( } inline PluginOpKernelContext* CreatePluginOpKernelContext(void* ctx) { - if (!npd::kTfNextPluggableDeviceUseCApi) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { return new DirectPluginOpKernelContext(ctx); } else { return new CPluginOpKernelContext(ctx); From 56f261ba2fad322bacbb69aa1eeaff46b5d2910c Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 25 Jul 2023 13:43:56 -0700 Subject: [PATCH 136/410] [XLA:GPU] Use more measured HLO operation durations in fusion performance modeling. Also add a tool that is used to measure these. 
PiperOrigin-RevId: 550984284 --- tensorflow/compiler/xla/service/gpu/BUILD | 94 +- .../compiler/xla/service/gpu/fusion_merger.cc | 8 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 2 +- .../xla/service/gpu/gpu_hlo_cost_analysis.cc | 137 +- .../xla/service/gpu/gpu_hlo_cost_analysis.h | 10 +- .../xla/service/gpu/gpu_performance_model.cc | 41 +- .../xla/service/gpu/gpu_performance_model.h | 5 +- .../service/gpu/gpu_performance_model_test.cc | 679 +---- .../xla/service/gpu/hlo_op_profile.proto | 18 + .../xla/service/gpu/hlo_op_profiler.cc | 185 ++ .../xla/service/gpu/hlo_op_profiler.h | 60 + .../xla/service/gpu/hlo_op_profiler_run.cc | 127 + .../xla/service/gpu/hlo_op_profiler_test.cc | 46 + .../xla/service/gpu/hlo_op_profiles.h | 2370 +++++++++++++++++ .../xla/service/gpu/multi_output_fusion.cc | 20 +- .../xla/service/gpu/priority_fusion.cc | 21 +- .../xla/service/gpu/priority_fusion.h | 5 +- tensorflow/compiler/xla/xla.bzl | 2 + tensorflow/core/BUILD | 1 + 19 files changed, 3031 insertions(+), 800 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profile.proto create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profiler.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profiler_run.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profiler_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/hlo_op_profiles.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d2f7da5ffd450f..425279a2e6da6f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -2910,14 +2910,22 @@ cc_library( cc_library( name = "gpu_hlo_cost_analysis", srcs = ["gpu_hlo_cost_analysis.cc"], - hdrs = ["gpu_hlo_cost_analysis.h"], + hdrs = [ + "gpu_hlo_cost_analysis.h", + "hlo_op_profiles.h", + ], compatible_with = get_compatible_with_portable(), deps = [ ":backend_configs_cc", ":cublas_cudnn", + ":gpu_device_info", + ":hlo_op_profile_proto_cc", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", ], ) @@ -2980,13 +2988,97 @@ xla_cc_test( srcs = ["gpu_performance_model_test.cc"], deps = [ ":backend_configs_cc", + ":gpu_device_info", ":gpu_device_info_for_tests", + ":gpu_hlo_cost_analysis", ":gpu_performance_model", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +tf_proto_library( + name = "hlo_op_profile_proto", + srcs = ["hlo_op_profile.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + protodeps = [ + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) + +cc_library( + name = "hlo_op_profiler_lib", + testonly = True, + srcs = ["hlo_op_profiler.cc"], + hdrs = ["hlo_op_profiler.h"], + deps = [ + ":gpu_device_info", + ":hlo_op_profile_proto_cc", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + 
"//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:interpreter_plugin", + "//tensorflow/compiler/xla/tests:test_utils", + "@com_google_absl//absl/time", + ], +) + +xla_cc_test( + name = "hlo_op_profiler_run", + timeout = "eternal", + srcs = ["hlo_op_profiler_run.cc"], + # This is a development tool, not a normal test, and thus should only be run + # manually. + tags = [ + "gpu", + "manual", + "notap", + "requires-gpu-nvidia", + ], + deps = [ + ":gpu_device_info", + ":hlo_op_profile_proto_cc", + ":hlo_op_profiler_lib", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/platform:test", + "//tensorflow/tsl/util:command_line_flags", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) +xla_cc_test( + name = "hlo_op_profiler_test", + srcs = if_cuda_is_configured(["hlo_op_profiler_test.cc"]), + tags = tf_cuda_tests_tags(), + deps = if_cuda_is_configured([ + ":hlo_op_profiler_lib", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/tsl/platform:test_main", + ]), +) + cc_library( name = "buffer_comparator", srcs = if_cuda_is_configured(["buffer_comparator.cc"]), diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index c91b8d5a1714a2..d823bb5706ab79 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -120,8 +120,7 @@ Status FusionInstructionMerger::FuseIntoAllUsers(HloInstruction* producer) { *consumer); } - GpuPerformanceModel::RecordEstimatedRunTime(consumer, &*cost_analysis_, - gpu_device_info_); + GpuPerformanceModel::RecordEstimatedRunTime(consumer, &*cost_analysis_); changed_ = true; } @@ -266,7 +265,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { cost_analysis_.emplace( GpuHloCostAnalysis::Options{shape_size_function_, /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}); + /*count_multiple_input_accesses=*/true}, + &gpu_device_info_); TF_CHECK_OK(computation_->Accept(&cost_analysis_.value())); } @@ -285,7 +285,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { .xla_gpu_enable_experimental_block_size(); GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - producer, &*cost_analysis_, gpu_device_info_, use_experimental_block_size, + producer, &*cost_analysis_, use_experimental_block_size, compute_capability_, producer->users()); if (t.time_fused > t.time_unfused) { ++num_fail_slower_if_fused_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 9155a69a4435cd..557abcdf6b1171 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -1417,7 +1417,7 @@ StatusOr> GpuCompiler::RunBackend( HloCostAnalysis::Options cost_analysis_options{ShapeSizeBytesFunction()}; 
cost_analysis_options.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); - GpuHloCostAnalysis cost_analysis(cost_analysis_options); + GpuHloCostAnalysis cost_analysis(cost_analysis_options, &gpu_device_info); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); if (!options.is_autotuning_compilation) { VLOG(1) << "HLO memory read+written: " diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc index d9f4927628b43e..708e5c03961d10 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.cc @@ -15,14 +15,20 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h" +#include #include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profile.pb.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profiles.h" namespace xla { namespace gpu { @@ -270,95 +276,52 @@ int64_t GpuHloCostAnalysis::GetConvolutionFlops( result_shape); } -Status GpuHloCostAnalysis::HandleElementwiseOp(const HloInstruction* hlo) { - const HloOpcode opcode = hlo->opcode(); - const auto& shape = hlo->shape(); - const PrimitiveType type = shape.element_type(); - - // These are clock cycle estimates of some of the most common expensive - // operations. They most likely vary a lot from GPU to GPU but should - // at least provide reasonable comparisons for the computation cost analysis. - // HLOs used to measure these can be found in gpu_performance_model_test.cc - // This list is far from complete yet. - // TODO(b/256570878): Make a tool to measure these numbers and store them - // separately from the code where possible. - - // Typical elementwise instructions take about 3 clock cycles. - int64_t flop_per_element = 3; - switch (opcode) { - case HloOpcode::kTanh: - if (type == F32) { - flop_per_element = 30; - } else if (type == F64) { - flop_per_element = 2000; - } - break; - case HloOpcode::kDivide: - if (type == S32) { - flop_per_element = 80; - } else if (type == F64) { - flop_per_element = 3200; - } else if (type == C128) { - flop_per_element = 20000; - } - break; - // Expands to multiple instructions. 
- case HloOpcode::kExp: - if (type == F64) { - flop_per_element = 2200; - } - break; - case HloOpcode::kSqrt: - if (type == F64) { - flop_per_element = 1100; - } else if (type == C128) { - flop_per_element = 25000; - } - break; - case HloOpcode::kRsqrt: - if (type == F64) { - flop_per_element = 900; - } - break; - case HloOpcode::kAdd: - if (type == F64) { - flop_per_element = 120; - } else if (type == C128) { - flop_per_element = 240; - } - break; - case HloOpcode::kMultiply: - if (type == F64) { - flop_per_element = 120; - } else if (type == C128) { - flop_per_element = 650; - } - break; - case HloOpcode::kPower: - if (type == F64) { - flop_per_element = 11000; - } else if (type == C128) { - flop_per_element = 28000; - } - break; - case HloOpcode::kLog: - if (type == F32) { - flop_per_element = 45; - } else if (type == F64) { - flop_per_element = 1000; - } - break; - default: - // Raise default cost of all unlisted F64 and C128 ops. - if (type == F64) { - flop_per_element = 10; - } else if (type == C128) { - flop_per_element = 20; +using profiles_nested_map = absl::flat_hash_map< + std::string, // Device name. + absl::flat_hash_map>>; + +const profiles_nested_map* LoadOpProfiles() { + static profiles_nested_map* profiles = [] { + profiles_nested_map* ret = new profiles_nested_map; + DeviceHloInstructionProfiles all_device_profiles; + CHECK(tsl::protobuf::TextFormat::ParseFromString( + std::string(kDeviceHloOpProfiles), &all_device_profiles)); + for (const auto& device_profile : all_device_profiles.entries()) { + for (const auto& entry : device_profile.second.entries()) { + (*ret)[device_profile.first][entry.instruction().shape().element_type()] + [StringToHloOpcode(entry.instruction().opcode()).value()] = + entry.clock_cycles(); } - break; + } + return ret; + }(); + return profiles; +} + +// Elementwise instructions typically take at least a few clock cycles. +constexpr int64_t kDefaultFlopsPerElement = 3; + +constexpr absl::string_view kDefaultDeviceName = "NVIDIA RTX A6000"; + +int64_t FlopsPerElement(const absl::string_view device_name, + const PrimitiveType type, const HloOpcode opcode) { + const profiles_nested_map* all_profiles = LoadOpProfiles(); + auto device_profiles = FindOrDefault(*all_profiles, std::string(device_name), + all_profiles->at(kDefaultDeviceName)); + auto dtype_profiles = MaybeFind(device_profiles, type); + if (!dtype_profiles.ok()) { + return kDefaultFlopsPerElement; } + return FindOrDefault(dtype_profiles->get(), opcode, kDefaultFlopsPerElement); +} + +Status GpuHloCostAnalysis::HandleElementwiseOp(const HloInstruction* hlo) { + int64_t flop_per_element = + FlopsPerElement(device_info_ ? 
device_info_->name : kDefaultDeviceName, + hlo->shape().element_type(), hlo->opcode()); current_properties_[kFlopsKey] = - flop_per_element * ShapeUtil::ElementsInRecursive(shape); + flop_per_element * ShapeUtil::ElementsInRecursive(hlo->shape()); return OkStatus(); } @@ -372,7 +335,7 @@ Status GpuHloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) { std::unique_ptr GpuHloCostAnalysis::CreateNestedCostAnalysis() { - return std::make_unique(options_); + return std::make_unique(options_, device_info_); } bool GpuHloCostAnalysis::KeyToCopyFromSubcomputation( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h index b138e1786e22ca..208fc28218883d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h @@ -17,8 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_COST_ANALYSIS_H_ #include -#include +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" namespace xla { @@ -32,8 +33,9 @@ class GpuHloCostAnalysis : public HloCostAnalysis { static constexpr int64_t kMaxIRSize = 10000; public: - explicit GpuHloCostAnalysis(const Options& options) - : HloCostAnalysis(options) {} + explicit GpuHloCostAnalysis(const Options& options, + const GpuDeviceInfo* device_info = nullptr) + : HloCostAnalysis(options), device_info_(device_info) {} Status Preprocess(const HloInstruction* hlo) override; @@ -63,6 +65,8 @@ class GpuHloCostAnalysis : public HloCostAnalysis { float CommonElementwiseUtilization(const HloInstruction* a, const HloInstruction* b) const; + const GpuDeviceInfo* device_info_; + protected: std::unique_ptr CreateNestedCostAnalysis() override; int64_t FusionParameterReadBytes(const HloInstruction* hlo) const override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc index c0a20c96e5d9d7..f37f4c9410b30f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc @@ -168,20 +168,19 @@ struct EstimateRunTimeData { absl::Duration exec_time; }; -EstimateRunTimeData EstimateRunTimeImpl(const HloInstruction* instr, - const GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& gpu_device_info) { +EstimateRunTimeData EstimateRunTimeImpl( + const HloInstruction* instr, const GpuHloCostAnalysis* cost_analysis) { int64_t flops = cost_analysis->flop_count(*instr); float bytes_written = cost_analysis->output_bytes_accessed(*instr); float bytes_read = cost_analysis->bytes_accessed(*instr) - bytes_written; float elements_out = ShapeUtil::ElementsInRecursive(instr->shape()); absl::Duration compute_time = - ComputeTime(gpu_device_info, flops, elements_out); - absl::Duration read_time = - ProducerInputAccessTime(cost_analysis, gpu_device_info, instr); - absl::Duration write_time = - absl::Seconds(bytes_written / gpu_device_info.memory_bandwidth); + ComputeTime(*cost_analysis->device_info_, flops, elements_out); + absl::Duration read_time = ProducerInputAccessTime( + cost_analysis, *cost_analysis->device_info_, instr); + absl::Duration write_time = absl::Seconds( + bytes_written / cost_analysis->device_info_->memory_bandwidth); absl::Duration exec_time = std::max(compute_time, read_time + write_time); if 
(VLOG_IS_ON(8)) { @@ -201,7 +200,7 @@ EstimateRunTimeData EstimateRunTimeImpl(const HloInstruction* instr, GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& gpu_device_info, bool use_experimental_block_size, + bool use_experimental_block_size, std::optional cc, std::vector fused_users, bool multi_output) { VLOG(8) << "Producer: " << producer->name(); @@ -210,7 +209,7 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( } EstimateRunTimeData producer_data = - EstimateRunTimeImpl(producer, cost_analysis, gpu_device_info); + EstimateRunTimeImpl(producer, cost_analysis); int64_t fused_consumer_count = fused_users.size(); float total_producer_utilization = 0; @@ -222,19 +221,21 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( cost_analysis->operand_utilization(*u, u->operand_index(producer)); total_producer_utilization += utilization_by_this_consumer; - auto thread_count = EstimateThreadCount(u, gpu_device_info, cc, + auto thread_count = EstimateThreadCount(u, *cost_analysis->device_info_, cc, use_experimental_block_size); int64_t upper_bound = producer_data.elements_out * utilization_by_this_consumer; absl::Duration compute_time_by_this_consumer = ComputeTime( - gpu_device_info, producer_data.flops * utilization_by_this_consumer, + *cost_analysis->device_info_, + producer_data.flops * utilization_by_this_consumer, thread_count.has_value() ? std::min(*thread_count, upper_bound) : upper_bound); - exec_time_fused += std::max( - compute_time_by_this_consumer, - ProducerInputAccessTime(cost_analysis, gpu_device_info, producer, u)); + exec_time_fused += + std::max(compute_time_by_this_consumer, + ProducerInputAccessTime( + cost_analysis, *cost_analysis->device_info_, producer, u)); producer_output_read_time_unfused += ReadTime( - gpu_device_info, + *cost_analysis->device_info_, std::min(producer_data.bytes_written, producer_data.bytes_written * utilization_by_this_consumer), producer_data.bytes_written * utilization_by_this_consumer); @@ -264,15 +265,13 @@ GpuPerformanceModel::RunTimes GpuPerformanceModel::EstimateRunTimes( } void GpuPerformanceModel::RecordEstimatedRunTime( - HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& gpu_device_info) { + HloInstruction* instruction, const GpuHloCostAnalysis* cost_analysis) { DCHECK(Cast(instruction)) << "expected fusion"; DCHECK(cost_analysis != nullptr) << "expected cost analysis"; - EstimateRunTimeData data = - EstimateRunTimeImpl(instruction, cost_analysis, gpu_device_info); + EstimateRunTimeData data = EstimateRunTimeImpl(instruction, cost_analysis); double cycles = absl::ToDoubleNanoseconds(data.exec_time) * - gpu_device_info.clock_rate_ghz; + cost_analysis->device_info_->clock_rate_ghz; auto backend_config = instruction->backend_config(); TF_CHECK_OK(backend_config.status()) << instruction->ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.h b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.h index 765c8b75ee373d..fc8b75847926cf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_performance_model.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_performance_model.h @@ -20,7 +20,6 @@ limitations under the License. 
#include #include "absl/time/time.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h" namespace xla { @@ -34,15 +33,13 @@ class GpuPerformanceModel { }; static RunTimes EstimateRunTimes( const HloInstruction* producer, const GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& gpu_device_info, bool use_experimental_block_size = false, std::optional cc = std::nullopt, std::vector fused_users = {}, bool multi_output = false); // Writes estimated execution time to FusionBackendConfig.reification_cost. static void RecordEstimatedRunTime(HloInstruction* instruction, - const GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& gpu_device_info); + const GpuHloCostAnalysis* cost_analysis); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_performance_model_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_performance_model_test.cc index bcacfc1402db2f..c96e4cc2d097b4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_performance_model_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_performance_model_test.cc @@ -15,11 +15,23 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_performance_model.h" +#include #include #include +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { @@ -27,21 +39,22 @@ namespace gpu { namespace { class GpuPerformanceModelTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { return [&](const Shape& shape) { constexpr int64_t kPointerSize = 8; return ShapeUtil::ByteSizeOf(shape, kPointerSize); }; } + GpuDeviceInfo dev_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; + public: - HloCostAnalysis::Options options_{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}; - GpuHloCostAnalysis analysis_{options_}; + GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. - GpuDeviceInfo device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + GpuHloCostAnalysis analysis_{options_, &dev_info_}; GpuPerformanceModelTest() : HloTestBase() {} }; @@ -64,7 +77,7 @@ ENTRY e { ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); // Dominated by the DRAM bandwidth. 
EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 57, 10); } @@ -91,11 +104,11 @@ ENTRY e { ASSERT_IS_OK(root->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); // Dominated by the kernel launch overhead. EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 2, 1); - GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_, device_info_); + GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_); double recorded_cycles = root->backend_config() ->reification_cost() .end_to_end_cycles(); @@ -124,11 +137,11 @@ ENTRY e { ASSERT_IS_OK(root->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); // Dominated by the DRAM bandwidth. EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 175, 30); - GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_, device_info_); + GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_); double recorded_cycles = root->backend_config() ->reification_cost() .end_to_end_cycles(); @@ -159,7 +172,7 @@ ENTRY e { ASSERT_IS_OK(root->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); // Parameter 0 read is accelerated by L1 cache even though the total data // volume is the same as in the test LargeReadWrite above. EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 118, 12); @@ -189,649 +202,11 @@ ENTRY e { ASSERT_IS_OK(root->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); // Parameter 0 read is accelerated by L2 cache (does not fit in L1). 
EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 123, 12); } -TEST_F(GpuPerformanceModelTest, S32Divide) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = s32[10000000] parameter(0) - b1 = s32[10000000] parameter(1) - d0 = s32[10000000] divide(b0, b1) - d1 = s32[10000000] divide(d0, b1) - d2 = s32[10000000] divide(d1, b1) - d3 = s32[10000000] divide(d2, b1) - d4 = s32[10000000] divide(d3, b1) - d5 = s32[10000000] divide(d4, b1) - d6 = s32[10000000] divide(d5, b1) - d7 = s32[10000000] divide(d6, b1) - d8 = s32[10000000] divide(d7, b1) - d9 = s32[10000000] divide(d8, b1) - d10 = s32[10000000] divide(d9, b1) - d11 = s32[10000000] divide(d10, b1) - d12 = s32[10000000] divide(d11, b1) - d13 = s32[10000000] divide(d12, b1) - d14 = s32[10000000] divide(d13, b1) - d15 = s32[10000000] divide(d14, b1) - d16 = s32[10000000] divide(d15, b1) - d17 = s32[10000000] divide(d16, b1) - d18 = s32[10000000] divide(d17, b1) - ROOT d19 = s32[10000000] divide(d18, b1) -} - -ENTRY e { - p0 = s32[10000000] parameter(0) - p1 = s32[10000000] parameter(1) - ROOT r.1 = s32[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 482, 48); -} - -TEST_F(GpuPerformanceModelTest, F32Log) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f32[10000000] parameter(0) - e0 = f32[10000000] log(b0) - e1 = f32[10000000] log(e0) - e2 = f32[10000000] log(e1) - e3 = f32[10000000] log(e2) - e4 = f32[10000000] log(e3) - e5 = f32[10000000] log(e4) - e6 = f32[10000000] log(e5) - e7 = f32[10000000] log(e6) - e8 = f32[10000000] log(e7) - e9 = f32[10000000] log(e8) - e10 = f32[10000000] log(e9) - e11 = f32[10000000] log(e10) - e12 = f32[10000000] log(e11) - e13 = f32[10000000] log(e12) - e14 = f32[10000000] log(e13) - e15 = f32[10000000] log(e14) - e16 = f32[10000000] log(e15) - e17 = f32[10000000] log(e16) - e18 = f32[10000000] log(e17) - e19 = f32[10000000] log(e18) - ROOT e20 = f32[10000000] log(e19) -} - -ENTRY e { - p0 = f32[10000000] parameter(0) - ROOT r.1 = f32[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 312, 31); - - GpuPerformanceModel::RecordEstimatedRunTime(root, &analysis_, device_info_); - double recorded_cycles = root->backend_config() - ->reification_cost() - .end_to_end_cycles(); - EXPECT_NEAR(recorded_cycles, 439452, 100); -} - -TEST_F(GpuPerformanceModelTest, F64Log) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - e0 = f64[10000000] log(b0) - e1 = f64[10000000] log(e0) - e2 = f64[10000000] log(e1) - e3 = f64[10000000] log(e2) - e4 = f64[10000000] log(e3) - e5 = f64[10000000] log(e4) - e6 = f64[10000000] log(e5) - e7 = f64[10000000] log(e6) - e8 = f64[10000000] log(e7) - e9 = f64[10000000] log(e8) - e10 = f64[10000000] log(e9) - e11 = f64[10000000] log(e10) - e12 = f64[10000000] log(e11) - e13 = f64[10000000] 
log(e12) - e14 = f64[10000000] log(e13) - e15 = f64[10000000] log(e14) - e16 = f64[10000000] log(e15) - e17 = f64[10000000] log(e16) - e18 = f64[10000000] log(e17) - e19 = f64[10000000] log(e18) - ROOT e20 = f64[10000000] log(e19) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - ROOT r.1 = f64[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 7100, 700); -} - -TEST_F(GpuPerformanceModelTest, F64DivideOnce) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - b1 = f64[10000000] parameter(1) - ROOT d0 = f64[10000000] divide(b0, b1) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - p1 = f64[10000000] parameter(1) - ROOT r.1 = f64[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 1100, 110); -} - -TEST_F(GpuPerformanceModelTest, F64Exp) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - e0 = f64[10000000] exponential(b0) - ROOT r0 = f64[10000000] exponential(e0) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - ROOT r.1 = f64[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 1400, 140); -} - -TEST_F(GpuPerformanceModelTest, F64DivideManyTimes) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - b1 = f64[10000000] parameter(1) - d0 = f64[10000000] divide(b0, b1) - d1 = f64[10000000] divide(d0, b1) - d2 = f64[10000000] divide(d1, b1) - d3 = f64[10000000] divide(d2, b1) - d4 = f64[10000000] divide(d3, b1) - d5 = f64[10000000] divide(d4, b1) - d6 = f64[10000000] divide(d5, b1) - d7 = f64[10000000] divide(d6, b1) - d8 = f64[10000000] divide(d7, b1) - d9 = f64[10000000] divide(d8, b1) - d10 = f64[10000000] divide(d9, b1) - d11 = f64[10000000] divide(d10, b1) - d12 = f64[10000000] divide(d11, b1) - d13 = f64[10000000] divide(d12, b1) - d14 = f64[10000000] divide(d13, b1) - d15 = f64[10000000] divide(d14, b1) - d16 = f64[10000000] divide(d15, b1) - d17 = f64[10000000] divide(d16, b1) - d18 = f64[10000000] divide(d17, b1) - ROOT d19 = f64[10000000] divide(d18, b1) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - p1 = f64[10000000] parameter(1) - ROOT r.1 = f64[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - 
GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 20000, 2000); -} - -TEST_F(GpuPerformanceModelTest, F64Multiply) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - b1 = f64[10000000] parameter(1) - d0 = f64[10000000] multiply(b0, b1) - d1 = f64[10000000] multiply(d0, b1) - d2 = f64[10000000] multiply(d1, b1) - d3 = f64[10000000] multiply(d2, b1) - d4 = f64[10000000] multiply(d3, b1) - d5 = f64[10000000] multiply(d4, b1) - d6 = f64[10000000] multiply(d5, b1) - d7 = f64[10000000] multiply(d6, b1) - d8 = f64[10000000] multiply(d7, b1) - d9 = f64[10000000] multiply(d8, b1) - d10 = f64[10000000] multiply(d9, b1) - d11 = f64[10000000] multiply(d10, b1) - d12 = f64[10000000] multiply(d11, b1) - d13 = f64[10000000] multiply(d12, b1) - d14 = f64[10000000] multiply(d13, b1) - d15 = f64[10000000] multiply(d14, b1) - d16 = f64[10000000] multiply(d15, b1) - d17 = f64[10000000] multiply(d16, b1) - d18 = f64[10000000] multiply(d17, b1) - ROOT d19 = f64[10000000] multiply(d18, b1) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - p1 = f64[10000000] parameter(1) - ROOT r.1 = f64[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 794, 80); -} - -TEST_F(GpuPerformanceModelTest, C128Multiply) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = c128[10000000] parameter(0) - b1 = c128[10000000] parameter(1) - d0 = c128[10000000] multiply(b0, b1) - d1 = c128[10000000] multiply(d0, b1) - d2 = c128[10000000] multiply(d1, b1) - d3 = c128[10000000] multiply(d2, b1) - d4 = c128[10000000] multiply(d3, b1) - d5 = c128[10000000] multiply(d4, b1) - d6 = c128[10000000] multiply(d5, b1) - d7 = c128[10000000] multiply(d6, b1) - d8 = c128[10000000] multiply(d7, b1) - d9 = c128[10000000] multiply(d8, b1) - d10 = c128[10000000] multiply(d9, b1) - d11 = c128[10000000] multiply(d10, b1) - d12 = c128[10000000] multiply(d11, b1) - d13 = c128[10000000] multiply(d12, b1) - d14 = c128[10000000] multiply(d13, b1) - d15 = c128[10000000] multiply(d14, b1) - d16 = c128[10000000] multiply(d15, b1) - d17 = c128[10000000] multiply(d16, b1) - d18 = c128[10000000] multiply(d17, b1) - ROOT d19 = c128[10000000] multiply(d18, b1) -} - -ENTRY e { - p0 = c128[10000000] parameter(0) - p1 = c128[10000000] parameter(1) - ROOT r.1 = c128[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 4700, 470); -} - -TEST_F(GpuPerformanceModelTest, C128Power) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = c128[10000000] parameter(0) - b1 = c128[10000000] parameter(1) - d0 = c128[10000000] power(b0, b1) - d1 = c128[10000000] power(d0, b1) - d2 = c128[10000000] power(d1, b1) - d3 = c128[10000000] power(d2, b1) - d4 = c128[10000000] power(d3, b1) - d5 = c128[10000000] power(d4, b1) 
- d6 = c128[10000000] power(d5, b1) - d7 = c128[10000000] power(d6, b1) - d8 = c128[10000000] power(d7, b1) - ROOT d9 = c128[10000000] power(d8, b1) -} - -ENTRY e { - p0 = c128[10000000] parameter(0) - p1 = c128[10000000] parameter(1) - ROOT r.1 = c128[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 93000, 9300); -} - -TEST_F(GpuPerformanceModelTest, F64Power) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - b1 = f64[10000000] parameter(1) - d0 = f64[10000000] power(b0, b1) - d1 = f64[10000000] power(d0, b1) - d2 = f64[10000000] power(d1, b1) - d3 = f64[10000000] power(d2, b1) - d4 = f64[10000000] power(d3, b1) - d5 = f64[10000000] power(d4, b1) - d6 = f64[10000000] power(d5, b1) - d7 = f64[10000000] power(d6, b1) - d8 = f64[10000000] power(d7, b1) - ROOT d9 = f64[10000000] power(d8, b1) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - p1 = f64[10000000] parameter(1) - ROOT r.1 = f64[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 36000, 3600); -} - -TEST_F(GpuPerformanceModelTest, F64Tanh) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - e0 = f64[10000000] tanh(b0) - e1 = f64[10000000] tanh(e0) - e2 = f64[10000000] tanh(e1) - e3 = f64[10000000] tanh(e2) - e4 = f64[10000000] tanh(e3) - e5 = f64[10000000] tanh(e4) - e6 = f64[10000000] tanh(e5) - e7 = f64[10000000] tanh(e6) - e8 = f64[10000000] tanh(e7) - e9 = f64[10000000] tanh(e8) - e10 = f64[10000000] tanh(e9) - e11 = f64[10000000] tanh(e10) - e12 = f64[10000000] tanh(e11) - e13 = f64[10000000] tanh(e12) - e14 = f64[10000000] tanh(e13) - e15 = f64[10000000] tanh(e14) - e16 = f64[10000000] tanh(e15) - e17 = f64[10000000] tanh(e16) - e18 = f64[10000000] tanh(e17) - e19 = f64[10000000] tanh(e18) - ROOT e20 = f64[10000000] tanh(e19) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - ROOT r.1 = f64[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 14000, 1400); -} - -TEST_F(GpuPerformanceModelTest, F32Tanh) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f32[10000000] parameter(0) - e0 = f32[10000000] tanh(b0) - e1 = f32[10000000] tanh(e0) - e2 = f32[10000000] tanh(e1) - e3 = f32[10000000] tanh(e2) - e4 = f32[10000000] tanh(e3) - e5 = f32[10000000] tanh(e4) - e6 = f32[10000000] tanh(e5) - e7 = f32[10000000] tanh(e6) - e8 = f32[10000000] tanh(e7) - e9 = f32[10000000] tanh(e8) - e10 = f32[10000000] tanh(e9) - e11 = f32[10000000] tanh(e10) 
- e12 = f32[10000000] tanh(e11) - e13 = f32[10000000] tanh(e12) - e14 = f32[10000000] tanh(e13) - e15 = f32[10000000] tanh(e14) - e16 = f32[10000000] tanh(e15) - e17 = f32[10000000] tanh(e16) - e18 = f32[10000000] tanh(e17) - e19 = f32[10000000] tanh(e18) - ROOT e20 = f32[10000000] tanh(e19) -} - -ENTRY e { - p0 = f32[10000000] parameter(0) - ROOT r.1 = f32[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 200, 20); -} - -TEST_F(GpuPerformanceModelTest, F64Sqrt) { - absl::string_view hlo_string = R"( -HloModule m -f { - b0 = f64[10000000] parameter(0) - e0 = f64[10000000] sqrt(b0) - e1 = f64[10000000] sqrt(e0) - e2 = f64[10000000] sqrt(e1) - e3 = f64[10000000] sqrt(e2) - e4 = f64[10000000] sqrt(e3) - e5 = f64[10000000] sqrt(e4) - e6 = f64[10000000] sqrt(e5) - e7 = f64[10000000] sqrt(e6) - e8 = f64[10000000] sqrt(e7) - e9 = f64[10000000] sqrt(e8) - e10 = f64[10000000] sqrt(e9) - e11 = f64[10000000] sqrt(e10) - e12 = f64[10000000] sqrt(e11) - e13 = f64[10000000] sqrt(e12) - e14 = f64[10000000] sqrt(e13) - e15 = f64[10000000] sqrt(e14) - e16 = f64[10000000] sqrt(e15) - e17 = f64[10000000] sqrt(e16) - e18 = f64[10000000] sqrt(e17) - e19 = f64[10000000] sqrt(e18) - ROOT e20 = f64[10000000] sqrt(e19) -} -ENTRY e { - p0 = f64[10000000] parameter(0) - ROOT r.1 = f64[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 7800, 780); -} - -TEST_F(GpuPerformanceModelTest, C128Sqrt) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = c128[10000000] parameter(0) - e0 = c128[10000000] sqrt(b0) - e1 = c128[10000000] sqrt(e0) - e2 = c128[10000000] sqrt(e1) - e3 = c128[10000000] sqrt(e2) - e4 = c128[10000000] sqrt(e3) - e5 = c128[10000000] sqrt(e4) - e6 = c128[10000000] sqrt(e5) - e7 = c128[10000000] sqrt(e6) - e8 = c128[10000000] sqrt(e7) - ROOTe9 = c128[10000000] sqrt(e8) -} - -ENTRY e { - p0 = c128[10000000] parameter(0) - ROOT r.1 = c128[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 83000, 8000); -} - -TEST_F(GpuPerformanceModelTest, F64Rsqrt) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = f64[10000000] parameter(0) - e0 = f64[10000000] rsqrt(b0) - e1 = f64[10000000] rsqrt(e0) - e2 = f64[10000000] rsqrt(e1) - e3 = f64[10000000] rsqrt(e2) - e4 = f64[10000000] rsqrt(e3) - e5 = f64[10000000] rsqrt(e4) - e6 = f64[10000000] rsqrt(e5) - e7 = f64[10000000] rsqrt(e6) - e8 = f64[10000000] rsqrt(e7) - e9 = f64[10000000] rsqrt(e8) - e10 = f64[10000000] rsqrt(e9) - e11 = f64[10000000] rsqrt(e10) - 
e12 = f64[10000000] rsqrt(e11) - e13 = f64[10000000] rsqrt(e12) - e14 = f64[10000000] rsqrt(e13) - e15 = f64[10000000] rsqrt(e14) - e16 = f64[10000000] rsqrt(e15) - e17 = f64[10000000] rsqrt(e16) - e18 = f64[10000000] rsqrt(e17) - e19 = f64[10000000] rsqrt(e18) - ROOT e20 = f64[10000000] rsqrt(e19) -} - -ENTRY e { - p0 = f64[10000000] parameter(0) - ROOT r.1 = f64[10000000] fusion(p0), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 6300, 630); -} - -TEST_F(GpuPerformanceModelTest, C128Divide) { - absl::string_view hlo_string = R"( -HloModule m - -f { - b0 = c128[10000000] parameter(0) - b1 = c128[10000000] parameter(1) - d0 = c128[10000000] divide(b0, b1) - d1 = c128[10000000] divide(d0, b1) - d2 = c128[10000000] divide(d1, b1) - d3 = c128[10000000] divide(d2, b1) - d4 = c128[10000000] divide(d3, b1) - d5 = c128[10000000] divide(d4, b1) - d6 = c128[10000000] divide(d5, b1) - d7 = c128[10000000] divide(d6, b1) - d8 = c128[10000000] divide(d7, b1) - ROOT d9 = c128[10000000] divide(d8, b1) -} - -ENTRY e { - p0 = c128[10000000] parameter(0) - p1 = c128[10000000] parameter(1) - ROOT r.1 = c128[10000000] fusion(p0, p1), kind=kLoop, calls=f -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - HloInstruction* root = module->entry_computation()->root_instruction(); - ASSERT_IS_OK(root->Accept(&analysis_)); - - HloInstruction* instruction = root; - GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - instruction, &analysis_, device_info_); - EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 64000, 6400); -} - TEST_F(GpuPerformanceModelTest, UnusedParameter) { Shape shape = ShapeUtil::MakeShape(F32, {100000}); @@ -860,7 +235,7 @@ TEST_F(GpuPerformanceModelTest, UnusedParameter) { ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); GpuPerformanceModel::RunTimes t = - GpuPerformanceModel::EstimateRunTimes(root, &analysis_, device_info_); + GpuPerformanceModel::EstimateRunTimes(root, &analysis_); EXPECT_NEAR(absl::ToInt64Microseconds(t.time_unfused), 2, 1); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profile.proto b/tensorflow/compiler/xla/service/gpu/hlo_op_profile.proto new file mode 100644 index 00000000000000..4d5fe2c6915d86 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profile.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package xla.gpu; + +import "tensorflow/compiler/xla/service/hlo.proto"; + +message HloInstructionProfile { + xla.HloInstructionProto instruction = 1; + int64 clock_cycles = 2; +} + +message HloInstructionProfileList { + repeated HloInstructionProfile entries = 1; +} + +message DeviceHloInstructionProfiles { + map entries = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.cc new file mode 100644 index 00000000000000..505682a9dd319b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.cc @@ -0,0 +1,185 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profile.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +/*static*/ std::unique_ptr HloOpProfiler::MakeModuleForMeasurements( + HloOpcode op, PrimitiveType data_type, int64_t n_elements, + int chain_length) { + const Shape shape = ShapeUtil::MakeShape(data_type, {n_elements}); + auto module = std::make_unique("m", HloModuleConfig{}); + HloComputation::Builder entry_builder("b"); + HloComputation::Builder fusion_builder("sb"); + + if (HloOpcodeArity(op) == 2) { + HloInstruction* pf0 = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "pf0")); + HloInstruction* pf1 = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "pf1")); + HloInstruction* last = pf0; + for (int i = 0; i < chain_length; ++i) { + last = fusion_builder.AddInstruction( + HloInstruction::CreateBinary(shape, op, last, pf1)); + } + HloComputation* subcomp = + module->AddEmbeddedComputation(fusion_builder.Build()); + HloInstruction* p0 = entry_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = entry_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p1")); + entry_builder.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {p0, p1}, subcomp)); + } else if (HloOpcodeArity(op) == 1) { + HloInstruction* pf = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "pf")); + HloInstruction* last = pf; + for (int i = 0; i < chain_length; ++i) { + last = fusion_builder.AddInstruction( + HloInstruction::CreateUnary(shape, op, last)); + } + HloComputation* subcomp = + module->AddEmbeddedComputation(fusion_builder.Build()); + HloInstruction* p0 = entry_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p0")); + entry_builder.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {p0}, subcomp)); + } else { + LOG(FATAL) << "Unsupported opcode: " << HloOpcodeString(op); + } + module->AddEntryComputation(entry_builder.Build()); + VLOG(9) << module->ToString(); + return module; +} + +StatusOr HloOpProfiler::MeasureOpChainDuration( + HloOpcode op, PrimitiveType data_type, int64_t input_size, + int 
chain_length) { + std::unique_ptr module = + MakeModuleForMeasurements(op, data_type, input_size, chain_length); + + std::minstd_rand0 engine; + // Some operations have dynamic duration that depends on the input values. + // Measure each operation with small and large inputs and average. + std::vector args_small = MakeFakeArguments(module.get(), &engine, + /*use_large_range=*/false) + .value(); + std::vector args_large = MakeFakeArguments(module.get(), &engine, + /*use_large_range=*/true) + .value(); + const absl::Time t_compile_start = absl::Now(); + TF_ASSIGN_OR_RETURN(std::unique_ptr ex, + runner_.CreateExecutable(std::move(module), + /*run_hlo_passes=*/false)); + if (absl::Now() - t_compile_start > absl::Seconds(10)) { + return ResourceExhausted("Too slow compilation"); + } + + // Warmup. + TF_RETURN_IF_ERROR( + runner_.ExecuteWithExecutable(ex.get(), args_small).status()); + + absl::Duration sum = absl::ZeroDuration(); + constexpr int kIterations = 10; + for (int i = 0; i < kIterations; ++i) { + ExecutionProfile profile_small; + TF_RETURN_IF_ERROR( + runner_.ExecuteWithExecutable(ex.get(), args_small, &profile_small) + .status()); + ExecutionProfile profile_large; + TF_RETURN_IF_ERROR( + runner_.ExecuteWithExecutable(ex.get(), args_large, &profile_large) + .status()); + sum += absl::Nanoseconds( + (profile_small.compute_time_ns() + profile_large.compute_time_ns()) / + 2); + } + return sum / kIterations; +} + +StatusOr HloOpProfiler::MeasureClockCyclesPerOp( + HloOpcode op, bool is_binary, PrimitiveType data_type, int64_t input_size) { + VLOG(2) << "Measuring " << HloOpcodeString(op) << " " + << primitive_util::LowercasePrimitiveTypeName(data_type); + + const absl::Duration overheads = + MeasureOpChainDuration(HloOpcode::kNegate, data_type, input_size, + /*chain_length=*/1) + .value(); + VLOG(3) << "Overheads: " << overheads; + + absl::Duration duration = absl::ZeroDuration(); + int chain_length = 1; + // Double the length of the operation chain until it becomes measurable + // compared to the overheads. + while (duration < 5 * overheads) { + TF_ASSIGN_OR_RETURN(duration, MeasureOpChainDuration( + op, data_type, input_size, chain_length)); + VLOG(3) << chain_length << "\t" << duration; + chain_length *= 2; + if (chain_length > kMaxOpChainLength) { + VLOG(2) << "The op is too fast to be measured with this method"; + return Unimplemented("op is too fast"); + } + } + + TF_ASSIGN_OR_RETURN( + absl::Duration double_duration, + MeasureOpChainDuration(op, data_type, input_size, chain_length)); + VLOG(3) << chain_length << "\t" << double_duration; + + // The difference between t_double and t corresponds to half of chain_length. 
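+  // For example, if the loop above exits with chain_length == 256, then
+  // `duration` was measured for a chain of 128 ops and `double_duration`
+  // for 256 ops, so their difference covers 256 / 2 = 128 ops.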
+ const absl::Duration time_per_op = + (double_duration - duration) * 2.0 / chain_length; + + const float clocks_per_nanosecond = + dev_info_.clock_rate_ghz * 2; // 2 for FMA + const int64_t n_clocks = + absl::ToInt64Nanoseconds(time_per_op) * clocks_per_nanosecond; + VLOG(3) << time_per_op << " = " << n_clocks << " clock cycles"; + HloInstructionProfile profile; + profile.mutable_instruction()->mutable_opcode()->assign(HloOpcodeString(op)); + profile.mutable_instruction()->mutable_shape()->set_element_type(data_type); + profile.set_clock_cycles(n_clocks); + return profile; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h new file mode 100644 index 00000000000000..180582512c3e6b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILER_H_ + +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profile.pb.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +class HloOpProfiler { + static std::unique_ptr MakeModuleForMeasurements( + HloOpcode op, PrimitiveType data_type, int64_t n_elements, + int chain_length); + StatusOr MeasureOpChainDuration(HloOpcode op, + PrimitiveType data_type, + int64_t input_size, + int chain_length); + + public: + explicit HloOpProfiler(HloRunner& runner) + : runner_(runner), + dev_info_(GetGpuDeviceInfo(runner.backend().stream_executors()[0])) {} + StatusOr MeasureClockCyclesPerOp( + HloOpcode op, bool binary, PrimitiveType data_type, int64_t input_size); + + private: + // Long chains can be too slow to compile. + static constexpr int kMaxOpChainLength = 4096; + + HloRunner& runner_; + const GpuDeviceInfo dev_info_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_run.cc b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_run.cc new file mode 100644 index 00000000000000..357db1fd2e3e5c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_run.cc @@ -0,0 +1,127 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profile.pb.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/init_main.h" +#include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/test.h" +#include "tensorflow/tsl/util/command_line_flags.h" + +namespace xla { +namespace gpu { +namespace { + +constexpr absl::string_view kUsage = R"( +This tool measures clock cycles per operation on GPU. +)"; + +void WriteOutput(const DeviceHloInstructionProfiles& literal, + absl::string_view name) { + std::string file_name; + std::string output_directory; + if (tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { + std::string filename = tsl::io::JoinPath( + output_directory, + absl::StrFormat("profiles-%d-%s", tsl::Env::Default()->NowMicros(), + name)); + file_name = absl::StrCat(filename, ".textproto"); + } else { + file_name = tsl::io::GetTempFilename(absl::StrCat(name, ".textproto")); + } + VLOG(0) << "Writing output to " << file_name; + TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), file_name, + literal.DebugString())); +} + +int RunProfiler(int argc, char** argv) { + std::string output_file; + std::vector flag_list = { + tsl::Flag("output_file", &output_file, + "Output measurements protobuf to the destination file."), + }; + AppendDebugOptionsFlags(&flag_list); + bool parse_ok = tsl::Flags::Parse(&argc, argv, flag_list); + tsl::port::InitMain(kUsage.data(), &argc, &argv); + if (!parse_ok) { + LOG(QFATAL) << "Error parsing flags"; + } + + HloRunner runner(PlatformUtil::GetPlatform("cuda").value()); + HloOpProfiler profiler(runner); + const gpu::GpuDeviceInfo dev_info = + gpu::GetGpuDeviceInfo(runner.backend().stream_executors()[0]); + VLOG(0) << dev_info.name << " @ " << dev_info.clock_rate_ghz << " GHz"; + + constexpr int64_t kInputSize = 1; + + const std::vector dtypes = {S8, S16, S32, S64, U8, U16, U32, + U64, F16, F32, F64, C64, C128}; + const std::vector unary_ops = { + HloOpcode::kCbrt, HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kLog, HloOpcode::kLog1p, + HloOpcode::kLogistic, HloOpcode::kRsqrt, HloOpcode::kSin, + HloOpcode::kSqrt, HloOpcode::kTanh}; + const std::vector binary_ops = { + HloOpcode::kAdd, HloOpcode::kAtan2, HloOpcode::kDivide, + HloOpcode::kMultiply, HloOpcode::kPower, HloOpcode::kSubtract}; + + HloInstructionProfileList instr_profiles; + + for (const PrimitiveType data_type : dtypes) { + for (const HloOpcode op : unary_ops) { + auto result = + profiler.MeasureClockCyclesPerOp(op, false, data_type, kInputSize); + if 
(result.ok()) { + instr_profiles.add_entries()->Swap(&*result); + } + } + for (const HloOpcode op : binary_ops) { + auto result = + profiler.MeasureClockCyclesPerOp(op, true, data_type, kInputSize); + if (result.ok()) { + instr_profiles.add_entries()->Swap(&*result); + } + } + } + + VLOG(1) << "\n" << instr_profiles.DebugString(); + + DeviceHloInstructionProfiles device_profiles; + device_profiles.mutable_entries()->insert({dev_info.name, instr_profiles}); + if (!output_file.empty()) { + WriteOutput(device_profiles, output_file); + } + + return 0; +} + +} // namespace +} // namespace gpu +} // namespace xla + +int main(int argc, char** argv) { return xla::gpu::RunProfiler(argc, argv); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_test.cc new file mode 100644 index 00000000000000..2bfbb9bd6de662 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profiler_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_op_profiler.h" + +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using HloOpProfilerTest = HloTestBase; + +TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { + HloOpProfiler profiler(test_runner_); + // f32 add is too fast to be measurable here. + EXPECT_FALSE( + profiler.MeasureClockCyclesPerOp(HloOpcode::kAdd, true, F32, 1).ok()); + // f64 divide is somewhat slow. + EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kDivide, true, F64, 1) + .value() + .clock_cycles(), + 500); + // c128 sqrt is slow. + EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kSqrt, false, C128, 1) + .value() + .clock_cycles(), + 5000); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_op_profiles.h b/tensorflow/compiler/xla/service/gpu/hlo_op_profiles.h new file mode 100644 index 00000000000000..83a6c426f1b6b8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_op_profiles.h @@ -0,0 +1,2370 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILES_H_ + +namespace xla { +namespace gpu { + +// The data below is obtained with +// xla/service/gpu:hlo_op_profiler_run + +constexpr char kDeviceHloOpProfiles[] = R"pb( + entries { + key: "NVIDIA RTX A6000" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 370 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 392 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 367 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 396 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 306 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 918 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 601 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 306 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 388 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 302 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 399 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 115 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 838 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 604 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 925 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 691 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 108 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 396 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 266 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 284 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F16 } + } + clock_cycles: 226 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 212 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 482 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 975 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 867 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 662 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 86 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { 
element_type: F32 } + } + clock_cycles: 381 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 244 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 262 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F32 } + } + clock_cycles: 176 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 662 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 75 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 190 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 486 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 925 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 6339 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1717 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 1652 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 1900 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 608 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 2073 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F64 } + } + clock_cycles: 2412 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 698 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1789 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 986 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 1609 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 3747 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 2016 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 5511 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 1360 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 1400 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 950 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 842 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 2383 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 3193 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 5353 + } 
+ entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 687 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 3351 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 6613 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 4028 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 4161 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 7599 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 6962 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 11318 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 5878 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 15606 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 9939 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 97 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 39027 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 7941 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + clock_cycles: 270 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 18205 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 97 + } + } + } + + entries { + key: "NVIDIA A100-SXM4-40GB" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 468 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 1094 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 420 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 391 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 454 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 908 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 744 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 1195 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 321 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 346 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 124 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 499 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 259 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 
504 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1221 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 1638 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 572 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 699 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1223 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 329 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 597 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 397 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 733 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 1080 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 831 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 1861 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 1037 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 1029 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 6618 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 4131 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 2309 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 2371 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 2405 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 3945 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 2284 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 5304 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 3618 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 13564 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 3037 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 6054 + } + } + } + + entries { + key: "Tesla V100-SXM2-16GB" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 345 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 345 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 954 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 302 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 526 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 309 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 
544 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 749 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 820 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 1227 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 865 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 137 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 544 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 354 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 388 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 122 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 841 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 134 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 556 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 1279 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 1168 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 823 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 110 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 514 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 333 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 361 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 529 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 660 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1214 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 1392 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 673 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 474 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 676 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 618 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1061 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 290 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 667 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 391 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 709 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 1178 + } + entries { + instruction { + 
opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 682 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 1679 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 1762 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 1450 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 1141 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 1787 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 3935 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 7025 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 948 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 4277 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 2386 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 1881 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 1875 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 2622 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 2328 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 4531 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 2408 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 5388 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 3867 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 13794 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 3001 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 6046 + } + } + } + + entries { + key: "Tesla P100-SXM2-16GB" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 438 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 479 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 758 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 2037 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 2937 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 307 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 293 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 1708 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 2993 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 1661 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 213 + } + entries 
{ + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 778 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 598 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 538 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F16 } + } + clock_cycles: 402 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 130 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 453 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 1717 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 1672 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 168 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 731 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 435 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 589 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F32 } + } + clock_cycles: 343 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 1024 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 417 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 873 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1779 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 1649 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1175 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 639 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 911 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 935 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1421 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F64 } + } + clock_cycles: 1098 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F64 } + } + clock_cycles: 355 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1187 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 645 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 917 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 1394 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 959 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 2667 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 1726 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 1518 + } + entries { + instruction { + 
opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 4142 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 5069 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 4053 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 9469 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C64 } + } + clock_cycles: 1317 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 5617 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 3416 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 2730 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 2765 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 3106 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 2895 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 5922 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 3496 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 7014 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 5400 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 21766 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 4133 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 10458 + } + } + } + + entries { + key: "NVIDIA TU-AUTO-PROD" + value { + entries { + instruction { + opcode: "divide" + shape { element_type: S8 } + } + clock_cycles: 360 + } + entries { + instruction { + opcode: "power" + shape { element_type: S8 } + } + clock_cycles: 336 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S16 } + } + clock_cycles: 357 + } + entries { + instruction { + opcode: "power" + shape { element_type: S16 } + } + clock_cycles: 339 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S32 } + } + clock_cycles: 296 + } + entries { + instruction { + opcode: "divide" + shape { element_type: S64 } + } + clock_cycles: 979 + } + entries { + instruction { + opcode: "power" + shape { element_type: S64 } + } + clock_cycles: 495 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U8 } + } + clock_cycles: 293 + } + entries { + instruction { + opcode: "power" + shape { element_type: U8 } + } + clock_cycles: 334 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U16 } + } + clock_cycles: 290 + } + entries { + instruction { + opcode: "power" + shape { element_type: U16 } + } + clock_cycles: 336 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U32 } + } + clock_cycles: 118 + } + entries { + instruction { + opcode: "divide" + shape { element_type: U64 } + } + clock_cycles: 812 + } + entries { + instruction { + opcode: "power" + shape { element_type: U64 } + } + clock_cycles: 515 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F16 } + } + clock_cycles: 792 + } + entries { + instruction { + 
opcode: "cosine" + shape { element_type: F16 } + } + clock_cycles: 815 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F16 } + } + clock_cycles: 132 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F16 } + } + clock_cycles: 342 + } + entries { + instruction { + opcode: "log" + shape { element_type: F16 } + } + clock_cycles: 239 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F16 } + } + clock_cycles: 239 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F16 } + } + clock_cycles: 262 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F16 } + } + clock_cycles: 126 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F16 } + } + clock_cycles: 794 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F16 } + } + clock_cycles: 123 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F16 } + } + clock_cycles: 175 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F16 } + } + clock_cycles: 414 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F16 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "power" + shape { element_type: F16 } + } + clock_cycles: 1120 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F32 } + } + clock_cycles: 783 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F32 } + } + clock_cycles: 737 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F32 } + } + clock_cycles: 83 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F32 } + } + clock_cycles: 319 + } + entries { + instruction { + opcode: "log" + shape { element_type: F32 } + } + clock_cycles: 201 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F32 } + } + clock_cycles: 218 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F32 } + } + clock_cycles: 181 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: F32 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F32 } + } + clock_cycles: 717 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F32 } + } + clock_cycles: 74 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F32 } + } + clock_cycles: 167 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F32 } + } + clock_cycles: 414 + } + entries { + instruction { + opcode: "power" + shape { element_type: F32 } + } + clock_cycles: 1085 + } + entries { + instruction { + opcode: "cbrt" + shape { element_type: F64 } + } + clock_cycles: 6494 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: F64 } + } + clock_cycles: 1800 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: F64 } + } + clock_cycles: 1630 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: F64 } + } + clock_cycles: 1929 + } + entries { + instruction { + opcode: "log" + shape { element_type: F64 } + } + clock_cycles: 596 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: F64 } + } + clock_cycles: 1774 + } + entries { + instruction { + opcode: "logistic" + shape { element_type: F64 } + } + clock_cycles: 2430 + } + entries { + instruction { + opcode: "rsqrt" + shape { 
element_type: F64 } + } + clock_cycles: 705 + } + entries { + instruction { + opcode: "sine" + shape { element_type: F64 } + } + clock_cycles: 1805 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: F64 } + } + clock_cycles: 984 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: F64 } + } + clock_cycles: 1535 + } + entries { + instruction { + opcode: "add" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: F64 } + } + clock_cycles: 3744 + } + entries { + instruction { + opcode: "divide" + shape { element_type: F64 } + } + clock_cycles: 1915 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "power" + shape { element_type: F64 } + } + clock_cycles: 5538 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: F64 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C64 } + } + clock_cycles: 1702 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C64 } + } + clock_cycles: 1503 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C64 } + } + clock_cycles: 1474 + } + entries { + instruction { + opcode: "log" + shape { element_type: C64 } + } + clock_cycles: 835 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C64 } + } + clock_cycles: 737 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C64 } + } + clock_cycles: 2232 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C64 } + } + clock_cycles: 1632 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C64 } + } + clock_cycles: 2989 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C64 } + } + clock_cycles: 2263 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C64 } + } + clock_cycles: 4847 + } + entries { + instruction { + opcode: "power" + shape { element_type: C64 } + } + clock_cycles: 3219 + } + entries { + instruction { + opcode: "cosine" + shape { element_type: C128 } + } + clock_cycles: 6474 + } + entries { + instruction { + opcode: "exponential" + shape { element_type: C128 } + } + clock_cycles: 4962 + } + entries { + instruction { + opcode: "exponential-minus-one" + shape { element_type: C128 } + } + clock_cycles: 4037 + } + entries { + instruction { + opcode: "log" + shape { element_type: C128 } + } + clock_cycles: 7286 + } + entries { + instruction { + opcode: "log-plus-one" + shape { element_type: C128 } + } + clock_cycles: 6848 + } + entries { + instruction { + opcode: "rsqrt" + shape { element_type: C128 } + } + clock_cycles: 10748 + } + entries { + instruction { + opcode: "sine" + shape { element_type: C128 } + } + clock_cycles: 5391 + } + entries { + instruction { + opcode: "sqrt" + shape { element_type: C128 } + } + clock_cycles: 15981 + } + entries { + instruction { + opcode: "tanh" + shape { element_type: C128 } + } + clock_cycles: 9653 + } + entries { + instruction { + opcode: "add" + shape { element_type: C128 } + } + clock_cycles: 95 + } + entries { + instruction { + opcode: "atan2" + shape { element_type: C128 } + } + clock_cycles: 38206 + } + entries { + instruction { + opcode: "divide" + shape { element_type: C128 } + } + clock_cycles: 8040 + } + entries { + instruction { + opcode: "multiply" + shape { element_type: C128 } + } + 
clock_cycles: 273 + } + entries { + instruction { + opcode: "power" + shape { element_type: C128 } + } + clock_cycles: 18550 + } + entries { + instruction { + opcode: "subtract" + shape { element_type: C128 } + } + clock_cycles: 97 + } + } + } +)pb"; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_OP_PROFILES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 4af685bf3afac9..1add20eb6e9b23 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -176,7 +176,7 @@ FusionDecision OperandReachableFromProducer( std::vector GetProducerConsumerMultiOutputFusionCandidates( const HloInstruction* producer, const HloReachabilityMap& reachability, FusionInfoCache* fusion_info_cache, GpuHloCostAnalysis* cost_analysis, - const GpuDeviceInfo& device_info, se::CudaComputeCapability cc) { + se::CudaComputeCapability cc) { std::vector fusion_candidates; const HloComputation* computation = producer->parent(); const HloModule* module = computation->parent(); @@ -206,7 +206,8 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( }, &ShapesCompatibleForMultiOutputFusion, std::bind(OperandReachableFromProducer, _1, _2, std::cref(reachability)), - std::bind(FusionFitsInBudget, _1, _2, std::cref(device_info), + std::bind(FusionFitsInBudget, _1, _2, + std::cref(*cost_analysis->device_info_), /*is_consumer_producer_fusion=*/false, fusion_info_cache), [&](const HloInstruction& producer, const HloInstruction& consumer) -> FusionDecision { @@ -222,8 +223,7 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( .debug_options() .xla_gpu_enable_experimental_block_size(); GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( - &producer, cost_analysis, device_info, use_experimental_block_size, - cc, + &producer, cost_analysis, use_experimental_block_size, cc, // `EstimateRunTimes`'s interface violates const correctness, so we // need the const cast here. {const_cast(&consumer)}, @@ -312,7 +312,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, // practice. std::bind(ParameterSlicesAreNonOverlapping, _1, _2, parent), // This check should be last, as it may be expensive. - std::bind(LegalToFuse, _1, _2, std::cref(device_info_), + std::bind(LegalToFuse, _1, _2, std::cref(*cost_analysis->device_info_), fusion_info_cache)}; for (auto i = siblings.begin(); i != siblings.end(); ++i) { VLOG(3) << "Considering " << (*i)->name(); @@ -385,7 +385,8 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { RecomputeReachability(); GpuHloCostAnalysis cost_analysis({shape_size_function_, /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}); + /*count_multiple_input_accesses=*/true}, + &device_info_); TF_RETURN_IF_ERROR(computation_->Accept(&cost_analysis)); std::vector defs_before_uses = computation_->MakeInstructionPostOrder(); @@ -415,7 +416,7 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { // traversal, and hence, not get into the way of subsequent fusion attempts. 
const auto candidates = GetProducerConsumerMultiOutputFusionCandidates( producer, *reachability_, &fusion_info_cache, &cost_analysis, - device_info_, compute_capability_); + compute_capability_); auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates); if (consumer_for_fusion == nullptr) { continue; @@ -457,7 +458,7 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { "| inside GPU multi-output fusion")); RecomputeReachability(); GpuPerformanceModel::RecordEstimatedRunTime(consumer_for_fusion, - &cost_analysis, device_info_); + &cost_analysis); continue; } HloInstruction* input_fusion = @@ -488,8 +489,7 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { absl::StrCat("Fusing producer |", producer_name, "| into consumer |", input_fusion->name(), "| inside GPU multi-output fusion")); RecomputeReachability(); - GpuPerformanceModel::RecordEstimatedRunTime(input_fusion, &cost_analysis, - device_info_); + GpuPerformanceModel::RecordEstimatedRunTime(input_fusion, &cost_analysis); } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/priority_fusion.cc b/tensorflow/compiler/xla/service/gpu/priority_fusion.cc index 1abc3fd48ea9c1..2e258ce5f6cc9a 100644 --- a/tensorflow/compiler/xla/service/gpu/priority_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/priority_fusion.cc @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/gpu/gpu_performance_model.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -70,11 +69,9 @@ class GpuPriorityFusionQueue : public FusionQueue { public: GpuPriorityFusionQueue( - HloComputation* computation, const GpuDeviceInfo& d, - GpuHloCostAnalysis* cost_analysis, + HloComputation* computation, GpuHloCostAnalysis* cost_analysis, const std::function& can_fuse) : computation_(computation), - gpu_device_info_(d), cost_analysis_(cost_analysis), can_fuse_(can_fuse) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); @@ -243,9 +240,9 @@ class GpuPriorityFusionQueue : public FusionQueue { } GpuPerformanceModel::RunTimes run_times = - GpuPerformanceModel::EstimateRunTimes( - producer, cost_analysis_, gpu_device_info_, - use_experimental_block_size, std::nullopt, fusible_users); + GpuPerformanceModel::EstimateRunTimes(producer, cost_analysis_, + use_experimental_block_size, + std::nullopt, fusible_users); return absl::ToInt64Nanoseconds(run_times.time_unfused - run_times.time_fused); } @@ -273,9 +270,6 @@ class GpuPriorityFusionQueue : public FusionQueue { // Store computation for cost analysis. HloComputation* computation_; - // Data that describes the execution target. - const GpuDeviceInfo gpu_device_info_; - // Reference to cost model that defines priorities in the queue. GpuHloCostAnalysis* cost_analysis_; @@ -374,7 +368,8 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. 
- if (auto fusible = FusionFitsInBudget(*consumer, *producer, device_info_, + if (auto fusible = FusionFitsInBudget(*consumer, *producer, + *cost_analysis_->device_info_, /*is_consumer_producer_fusion=*/true); !fusible) { return fusible; @@ -408,14 +403,14 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( result = InstructionFusion::FuseInstruction(fusion_instruction, producer); } GpuPerformanceModel::RecordEstimatedRunTime(fusion_instruction, - &*cost_analysis_, device_info_); + &*cost_analysis_); return result; } std::unique_ptr GpuPriorityFusion::GetFusionQueue( HloComputation* computation) { return std::unique_ptr(new GpuPriorityFusionQueue( - computation, device_info_, &*cost_analysis_, + computation, &*cost_analysis_, [this](HloInstruction* consumer, int64_t operand_index) { return ShouldFuse(consumer, operand_index).CanFuse(); })); diff --git a/tensorflow/compiler/xla/service/gpu/priority_fusion.h b/tensorflow/compiler/xla/service/gpu/priority_fusion.h index 92f67fcd629d90..9f737e97c508ea 100644 --- a/tensorflow/compiler/xla/service/gpu/priority_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/priority_fusion.h @@ -21,15 +21,12 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" -#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -54,7 +51,7 @@ class GpuPriorityFusion : public InstructionFusion { StatusOr Run(HloModule* module, const absl::flat_hash_set& execution_threads) override { - cost_analysis_.emplace(cost_analysis_options_); + cost_analysis_.emplace(cost_analysis_options_, &device_info_); return InstructionFusion::Run(module, execution_threads); } diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 603469afa22efb..02f209e8a574da 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -61,6 +61,7 @@ def xla_cc_binary(deps = None, copts = tsl_copts(), **kwargs): "//tensorflow/compiler/xla/service:hlo_proto_cc_impl", "//tensorflow/compiler/xla/service:memory_space_assignment_proto_cc_impl", "//tensorflow/compiler/xla/service/gpu:backend_configs_cc_impl", + "//tensorflow/compiler/xla/service/gpu:hlo_op_profile_proto_cc_impl", "//tensorflow/compiler/xla/stream_executor:dnn_proto_cc_impl", "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:tensor_float_32_utils", @@ -92,6 +93,7 @@ def xla_cc_test( clean_dep("//tensorflow/compiler/xla/service:hlo_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla/service:memory_space_assignment_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla/service/gpu:backend_configs_cc_impl"), + clean_dep("//tensorflow/compiler/xla/service/gpu:hlo_op_profile_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla/stream_executor:dnn_proto_cc_impl"), clean_dep("//tensorflow/compiler/xla/stream_executor:stream_executor_impl"), clean_dep("//tensorflow/compiler/xla/stream_executor:device_id_utils"), diff --git a/tensorflow/core/BUILD 
b/tensorflow/core/BUILD index 18c419a7623c8e..cf1fe084ebae7b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1478,6 +1478,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto_cc_impl", "//tensorflow/compiler/xla/service:memory_space_assignment_proto_cc_impl", "//tensorflow/compiler/xla/service/gpu:backend_configs_cc_impl", + "//tensorflow/compiler/xla/service/gpu:hlo_op_profile_proto_cc_impl", ] + tf_protos_grappler_impl() + tf_monitoring_framework_deps(), # Alwayslink causes a cc_binary to "always link" in the # srcs for a given cc_library, even if they are unreferenced, see: From 6e1b56b12a93c5c093b6f78b723276c6b4bd2121 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Tue, 25 Jul 2023 13:54:19 -0700 Subject: [PATCH 137/410] [PJRT C API] Increase the minor version to reflect the recent change to add PJRT_CopyToDeviceStream_Destroy. PiperOrigin-RevId: 550987203 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index 3717879159e73c..76b01484a80757 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 11 +#define PJRT_API_MINOR 12 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in From b534f3a1fb6e9cbd7746b76af3fa916509f9d0ea Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Tue, 25 Jul 2023 14:05:19 -0700 Subject: [PATCH 138/410] Add `LinearOperatorAdjoint` and `LinearOperatorInversion` to their corresponding module-specific `__all__`. * This will help fix a downstream TFP issue. PiperOrigin-RevId: 550990706 --- tensorflow/python/ops/linalg/linear_operator_adjoint.py | 2 +- tensorflow/python/ops/linalg/linear_operator_inversion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/linalg/linear_operator_adjoint.py b/tensorflow/python/ops/linalg/linear_operator_adjoint.py index 4bc323a1dd576a..d765ea2f077c1f 100644 --- a/tensorflow/python/ops/linalg/linear_operator_adjoint.py +++ b/tensorflow/python/ops/linalg/linear_operator_adjoint.py @@ -22,7 +22,7 @@ from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.util.tf_export import tf_export -__all__ = [] +__all__ = ["LinearOperatorAdjoint"] @tf_export("linalg.LinearOperatorAdjoint") diff --git a/tensorflow/python/ops/linalg/linear_operator_inversion.py b/tensorflow/python/ops/linalg/linear_operator_inversion.py index 21c02e039f536e..75f522e0c3cfb4 100644 --- a/tensorflow/python/ops/linalg/linear_operator_inversion.py +++ b/tensorflow/python/ops/linalg/linear_operator_inversion.py @@ -19,7 +19,7 @@ from tensorflow.python.ops.linalg import linear_operator_util from tensorflow.python.util.tf_export import tf_export -__all__ = [] +__all__ = ["LinearOperatorInversion"] @tf_export("linalg.LinearOperatorInversion") From 34c678bb3fcc84354fccaba0d1a9dff8b22fdc24 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 14:13:07 -0700 Subject: [PATCH 139/410] Add a new `default_memory_space` method to `PjRtDevice` so that it can be used by IFRT and Jax frontend. 
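A minimal usage sketch of what this enables from Python (the import path and client constructor are assumptions for illustration; as the next paragraph notes, backends raise Unimplemented until they add real support):

  # Sketch only: assumes an in-tree build where xla_client is importable.
  from tensorflow.compiler.xla.python import xla_client

  client = xla_client.make_cpu_client()   # any PJRT-backed client
  device = client.local_devices()[0]
  # New binding added in this change; raises Unimplemented for now because
  # the CPU backend still stubs out default_memory_space().
  memory = device.default_memory()
  print(device.addressable_memories())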
Each platform/backend need to implement this method after memories are supported on that platform/backend. PiperOrigin-RevId: 550993089 --- tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h | 4 ++++ tensorflow/compiler/xla/pjrt/pjrt_client.h | 3 +++ .../compiler/xla/pjrt/pjrt_stream_executor_client.cc | 5 +++++ tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h | 2 ++ tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc | 4 ++++ tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h | 2 ++ tensorflow/compiler/xla/python/ifrt/mock.cc | 3 +++ tensorflow/compiler/xla/python/ifrt/mock.h | 2 ++ tensorflow/compiler/xla/python/py_compile_only_client.cc | 3 +++ tensorflow/compiler/xla/python/tpu_driver/client/BUILD | 2 -- .../compiler/xla/python/tpu_driver/client/tpu_client.h | 5 +++++ tensorflow/compiler/xla/python/xla.cc | 7 +++++++ tensorflow/compiler/xla/python/xla_client.py | 2 +- tensorflow/compiler/xla/python/xla_extension/__init__.pyi | 1 + 14 files changed, 42 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h index a7138dd0034b45..38a9350d6a3a71 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h @@ -79,6 +79,10 @@ class PjRtCApiDevice : public PjRtDevice { return Unimplemented("PJRT C API does not support TransferFromOutfeed"); } + StatusOr default_memory_space() const override { + return Unimplemented("PJRT C API does not support default_memory_space"); + } + std::unique_ptr CreateAsyncTrackingEvent( absl::string_view description) const override { LOG(FATAL) << "PJRT C API does not support CreateAsyncTrackingEvent"; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index a8f7d942126848..2aefb40a4996e8 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -176,6 +176,9 @@ class PjRtDevice { virtual absl::Span memory_spaces() const { return {}; } + + // Returns the default memory space attached to this device. + virtual StatusOr default_memory_space() const = 0; }; // Forward declaration. 
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 3988575f923149..90e1a51b3e5dd2 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -1140,6 +1140,11 @@ Status PjRtStreamExecutorDevice::TransferFromOutfeed( local_device->device_ordinal(), literal); } +StatusOr PjRtStreamExecutorDevice::default_memory_space() + const { + return Unimplemented("default_memory_space is not supported."); +} + StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( int local_hardware_id) const { for (auto* device : addressable_devices_) { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index 185120813ab5ba..9032d174621094 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -157,6 +157,8 @@ class PjRtStreamExecutorDevice : public PjRtDevice { Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; + StatusOr default_memory_space() const override; + std::unique_ptr CreateAsyncTrackingEvent( absl::string_view description) const override { return nullptr; diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc index 123eeacbc18acd..924680d857613e 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc @@ -383,6 +383,10 @@ Status TfrtCpuDevice::TransferFromOutfeed(MutableBorrowingLiteral literal) { return TransferLiteralFromOutfeedOnCpu(local_hardware_id(), literal); } +StatusOr TfrtCpuDevice::default_memory_space() const { + return Unimplemented("default_memory_space is not supported"); +} + static int CpuDeviceCount() { // By default we fix the number of devices to one. However we do let the user // override this behavior to help run tests on the host that run models in diff --git a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h index bf23a5bfb02f87..782eb54075a121 100644 --- a/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h @@ -113,6 +113,8 @@ class TfrtCpuDevice final : public PjRtDevice { Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; + StatusOr default_memory_space() const override; + // Returns a semaphore for admission control on inflight computations. 
Semaphore& max_inflight_computations_semaphore() { return max_inflight_computations_semaphore_; diff --git a/tensorflow/compiler/xla/python/ifrt/mock.cc b/tensorflow/compiler/xla/python/ifrt/mock.cc index 17854f97ecd640..b8844ee00fc6ba 100644 --- a/tensorflow/compiler/xla/python/ifrt/mock.cc +++ b/tensorflow/compiler/xla/python/ifrt/mock.cc @@ -191,6 +191,9 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) { .WillByDefault([this](MutableBorrowingLiteral literal) { return delegated_->TransferFromOutfeed(std::move(literal)); }); + ON_CALL(*this, default_memory_space).WillByDefault([this]() { + return delegated_->default_memory_space(); + }); ON_CALL(*this, GetAllocatorStats).WillByDefault([this]() { return delegated_->GetAllocatorStats(); }); diff --git a/tensorflow/compiler/xla/python/ifrt/mock.h b/tensorflow/compiler/xla/python/ifrt/mock.h index 3cb77ff4275257..346fd8e03dad55 100644 --- a/tensorflow/compiler/xla/python/ifrt/mock.h +++ b/tensorflow/compiler/xla/python/ifrt/mock.h @@ -176,6 +176,8 @@ class MockDevice final : public Device { MOCK_METHOD(Status, TransferToInfeed, (const LiteralSlice& literal), (final)); MOCK_METHOD(Status, TransferFromOutfeed, (MutableBorrowingLiteral literal), (final)); + MOCK_METHOD(StatusOr, default_memory_space, (), + (const, final)); MOCK_METHOD(StatusOr, GetAllocatorStats, (), (const, final)); MOCK_METHOD(absl::Span, memory_spaces, (), diff --git a/tensorflow/compiler/xla/python/py_compile_only_client.cc b/tensorflow/compiler/xla/python/py_compile_only_client.cc index 03f04dc0edd931..6e61e1e23a5e7a 100644 --- a/tensorflow/compiler/xla/python/py_compile_only_client.cc +++ b/tensorflow/compiler/xla/python/py_compile_only_client.cc @@ -53,6 +53,9 @@ class PjRtCompileOnlyDevice : public PjRtDevice { Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { return Unimplemented("TransferFromOutfeed is not supported"); } + StatusOr default_memory_space() const override { + return Unimplemented("default_memory_space is not supported"); + } private: const PjRtDeviceDescription* description_; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 5f772f8572183b..26605659887727 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -34,10 +34,8 @@ cc_library( "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/tsl/framework:allocator", - "//tensorflow/tsl/platform:casts", "//tensorflow/tsl/platform:env", "//tensorflow/tsl/profiler/lib:traceme", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index a70cc6429c9142..20b06ad863ab48 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -116,6 +116,11 @@ class TpuDevice : public PjRtDevice { return Unimplemented("Outfeed not yet implemented via this API"); } + StatusOr default_memory_space() const override { + return Unimplemented( + "default_memory_space not yet implemented via this API"); + } + std::unique_ptr CreateAsyncTrackingEvent( absl::string_view description) const override { return nullptr; diff --git 
a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index ba86a32c5609f7..f599508105b22f 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -232,6 +232,13 @@ PYBIND11_MODULE(xla_extension, m) { return jax::GetMemory(device, kind); }, py::arg("kind")) + // Returns the default memory of a device. + .def("default_memory", + [](const ClientAndPtr& device) { + auto* memory_space = + xla::ValueOrThrow(device->default_memory_space()); + return WrapWithClient(device.client(), memory_space); + }) // Returns all the memories that a device can address. .def("addressable_memories", [](const ClientAndPtr& device) { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 6bfb6f15e4cfce..f869fcd39e1f66 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -44,7 +44,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. -_version = 173 +_version = 174 # Version number for MLIR:Python components. mlir_api_version = 54 diff --git a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi index 64508460b9710d..71982c789addfc 100644 --- a/tensorflow/compiler/xla/python/xla_extension/__init__.pyi +++ b/tensorflow/compiler/xla/python/xla_extension/__init__.pyi @@ -355,6 +355,7 @@ class Device: def transfer_to_infeed(self, literal: _LiteralSlice): ... def transfer_from_outfeed(self, shape: Shape): ... def memory(self, kind: str) -> Memory: ... + def default_memory(self) -> Memory: ... def addressable_memories(self) -> List[Memory]: ... def live_buffers(self) -> List[Any]: ... def memory_stats(self) -> Optional[Dict[str, int]]: ... From ed715e10549ff069c87776c6bd32057aa7cb1ab1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 14:35:08 -0700 Subject: [PATCH 140/410] Integrate LLVM at llvm/llvm-project@049d6a3f428e Updates LLVM usage to match [049d6a3f428e](https://github.com/llvm/llvm-project/commit/049d6a3f428e) PiperOrigin-RevId: 550999404 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 906573eae3531e..afed500e54f70d 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "c6f66de21af060ead6e5402858351e9e869dc15f" - LLVM_SHA256 = "91d75027562704b7e9941bbeff6b174ecfe5ea26be5dcd149b5131a2109520ea" + LLVM_COMMIT = "049d6a3f428efeb1a22f62e55b808f60b0bf27cc" + LLVM_SHA256 = "559ecaa58d53ac29c6da5df8b8b358ba8ec99f3cd8e7d08babe165f8506c2c19" tf_http_archive( name = name, From adbb3dd2a32f4ee9c5f69ed32f802a07b146a692 Mon Sep 17 00:00:00 2001 From: Deqiang Chen Date: Tue, 25 Jul 2023 14:49:31 -0700 Subject: [PATCH 141/410] Create python proto for runtime config to allow its python usage. 
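The practical effect is that messages defined in config.proto become importable from Python through the (internally enabled) config_proto_py_pb2 target added below. A minimal sketch, assuming the generated module follows the usual <proto name>_pb2 convention:

  # Sketch only: the module path is assumed from the proto's location, and the
  # concrete message names live in config.proto, which is not shown here.
  from tensorflow.core.tfrt.graph_executor import config_pb2

  # Any message from config.proto can now be constructed and serialized in
  # Python, e.g. config_pb2.<SomeMessage>().SerializeToString().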
PiperOrigin-RevId: 551003640 --- tensorflow/core/tfrt/graph_executor/BUILD | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index fcbbce16ae2e81..8a46aa94d33201 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -180,6 +180,15 @@ tf_proto_library( visibility = ["//visibility:public"], ) +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "config_proto_py_pb2", +# api_version = 2, +# visibility = ["//visibility:public"], +# deps = [":config_proto"], +# ) +# copybara:uncomment_end + tf_proto_library( name = "test_config_proto", testonly = True, From d4f58eb6083c829ca03d3cfd78e2e1ed5f4c9181 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 25 Jul 2023 14:59:14 -0700 Subject: [PATCH 142/410] Add a warning message to the docstring for the tf.experimental.dtensor.DTensorDataset. The DTensor dataset isn't a tf.data.Dataset, and only support the API for iterator creation and element_spec inspection. This is on par with existing `tf.distribute.DistributeDataset`. PiperOrigin-RevId: 551006420 --- tensorflow/dtensor/python/input_util.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/dtensor/python/input_util.py b/tensorflow/dtensor/python/input_util.py index 54ccea6385df6a..280fcdabe7d081 100644 --- a/tensorflow/dtensor/python/input_util.py +++ b/tensorflow/dtensor/python/input_util.py @@ -419,6 +419,11 @@ def __init__(self, For a DTensor mesh, the number of replicas is equal to the size of the mesh's batch dimension. + Note: `tf.experimental.dtensor.DTensorDataset` instances do *not* implement + the full interface of `tf.data.Dataset`. It only supports two usages we will + mention below: iteration and `element_spec`. We don't support any other APIs + to transform or inspect the dataset. + TODO(b/223275517): add support for input datasets that are already batched to the global batch size. From 75666d334d1c2217cd66a3771acc1ae8e6338b7f Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 25 Jul 2023 15:10:09 -0700 Subject: [PATCH 143/410] Add HLO_Commutative trait for commutative StableHLO ops PiperOrigin-RevId: 551009441 --- third_party/stablehlo/temporary.patch | 95 +++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index da15cb22c58f06..ee89192bb082c9 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -225,6 +225,32 @@ diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Ba // This interface is implemented by both StableHLO and MHLO dialects // and is used as the foundation for sharing verification, type inference and // prettyprinting logic between them. +@@ -249,6 +253,10 @@ + template + class BroadcastingElementwise + : public mlir::OpTrait::TraitBase {}; ++ ++template ++class IsCommutative ++ : public mlir::OpTrait::TraitBase {}; + + template + class PairwiseSameOperandAndResultType +diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td +--- stablehlo/stablehlo/dialect/Base.td ++++ stablehlo/stablehlo/dialect/Base.td +@@ -188,6 +188,11 @@ + // An operation that is essentially element-wise but may implement broadcasting + // semantics. + def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">; ++ ++// This class adds property that the operation is commutative. 
++// Upstream IsCommutative has default folders, and StableHLO aims to have no ++// default folders or canonicalizations. ++def HLO_Commutative : HLO_NativeOpTrait<"IsCommutative">; + + // Op has pairwise operand and result type matching: the number of operands + // must be equal to the number of results and the type of ith operand must diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt --- stablehlo/stablehlo/dialect/CMakeLists.txt +++ stablehlo/stablehlo/dialect/CMakeLists.txt @@ -819,6 +845,75 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +} // namespace mlir + +#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -1467,7 +1467,7 @@ + if (innerOp.getNumOperands() != 2 || + !innerOp.hasTrait() || + !hasSameOperandAndResultTypes(innerOp) || +- !innerOp.hasTrait() || ++ !innerOp.hasTrait() || + !innerOp.hasTrait()) + return false; + +@@ -1664,7 +1664,7 @@ + if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") || + !innerOpNameInfo->hasTrait::Impl>() || + !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait() || ++ !innerOpNameInfo->hasTrait() || + !innerOpNameInfo->hasTrait()) { + parser.emitError(loc, + "expected the inner-op to be a commutative binary-op from " +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td +--- stablehlo/stablehlo/dialect/StablehloOps.td ++++ stablehlo/stablehlo/dialect/StablehloOps.td +@@ -687,7 +687,7 @@ + } + + def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Add operation"; + let description = [{ + Performs element-wise addition of two tensors `lhs` and `rhs` and produces a +@@ -769,7 +769,7 @@ + } + + def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Max operation"; + let description = [{ + Performs element-wise max operation on tensors `lhs` and `rhs` and produces +@@ -786,7 +786,7 @@ + } + + def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Min operation"; + let description = [{ + Performs element-wise min operation on tensors `lhs` and `rhs` and produces a +@@ -803,7 +803,7 @@ + } + + def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Mul operation"; + let description = [{ + Performs element-wise product of two tensors `lhs` and `rhs` and produces a +@@ -933,7 +933,7 @@ + // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + class StableHLO_BinaryBiwiseOrLogicalElementwiseOp : + StableHLO_BinaryElementwiseOp { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let arguments = (ins + HLO_PredOrIntTensor:$lhs, + HLO_PredOrIntTensor:$rhs diff --ruN 
a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir From 6013e897e7dcf3809c79e9eeac205a5a4b0f6b08 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Tue, 25 Jul 2023 15:17:14 -0700 Subject: [PATCH 144/410] Preserve attributes while applying canonicalization patterns. C++ canonicazation patterns genearated from tablegen records currently does not preserve attributes. This may cause performance regression as GPU ops are placed in host CPUs after the rewrite, due to loss of `device` attribute. This adds a workaround to copy the op attributes during rewrites. This currently only fixes patterns for RealDiv ops, which caused the performance regression in b/281164776. Fixes for other ops will be added in subsequent CLs. PiperOrigin-RevId: 551011381 --- .../compiler/mlir/tensorflow/tests/canonicalize.mlir | 10 +++++----- .../mlir/tensorflow/transforms/canonicalize.td | 10 ++++++---- .../mlir/tensorflow/transforms/rewrite_util.cc | 7 +++++++ .../compiler/mlir/tensorflow/transforms/rewrite_util.h | 3 +++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 7bca5e649b2d96..0134a07c96a8a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -824,23 +824,23 @@ func.func @testDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf3 // CHECK-LABEL: testRealDivWithSqrtDivisor func.func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> - %1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = "tf.RealDiv"(%arg0, %0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> func.return %1: tensor<8x16xf32> -// CHECK: %0 = "tf.Rsqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> -// CHECK: %1 = "tf.Mul"(%arg0, %0) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %0 = "tf.Rsqrt"(%arg1) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %1 = "tf.Mul"(%arg0, %0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK: return %1 } // CHECK-LABEL: testRealDivWithConstDivisor func.func @testRealDivWithConstDivisor(%arg0: tensor<8x2xf32>) -> tensor<8x2xf32> { %0 = "tf.Const"() {value = dense<[2.0, 4.0]> : tensor<2xf32>} : () -> tensor<2xf32> - %1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32> + %1 = "tf.RealDiv"(%arg0, %0) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32> func.return %1: tensor<8x2xf32> // CHECK: %[[CONST:.*]] = "tf.Const" // CHECK-SAME: value = dense<[5.000000e-01, 2.500000e-01] - // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0, %[[CONST]]) + // CHECK: %[[MUL:.*]] = "tf.Mul"(%arg0, %[[CONST]]) {device = "/job:localhost/replica:0/task:0/device:GPU:0"} // CHECK: return %[[MUL]] } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 
ac803111db4c63..49261e9122e5e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -35,6 +35,8 @@ def HasOnlyReadVariableOpUsers : Constraint< CPred<"llvm::all_of($0.getUsers(), [](mlir::OpOperand op) { " "return llvm::isa(op.getOwner()); })">>; +def SetAttrs: NativeCodeCall<"CopyAttributes($0.getOwner(), $1)">; + //===----------------------------------------------------------------------===// // Add op patterns. //===----------------------------------------------------------------------===// @@ -230,15 +232,15 @@ def QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4 : Pat< // RealDiv op patterns. //===----------------------------------------------------------------------===// -def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), - (TF_MulOp $arg0, (TF_RsqrtOp $arg1))>; +def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp:$src $arg0, (TF_SqrtOp $arg1)), + (SetAttrs $src, (TF_MulOp $arg0, (SetAttrs $src, (TF_RsqrtOp $arg1))))>; // Replace division by a constant with a multiplication by a reciprocal of that // constant. Floating point division can be ~10x more expensive than a // multiplication. def RealDivWithConstDivisor : Pat< - (TF_RealDivOp $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)), - (TF_MulOp $arg0, (TF_ReciprocalOp (TF_ConstOp $value)))>; + (TF_RealDivOp:$src $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)), + (SetAttrs $src, (TF_MulOp $arg0, (SetAttrs $src, (TF_ReciprocalOp (TF_ConstOp $value)))))>; //===----------------------------------------------------------------------===// // Reshape op patterns. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc index 4fdcaf6f0803c1..c6e4ed4b1bd4de 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc @@ -51,5 +51,12 @@ bool IsOnGpuDevice(mlir::Operation *op) { return *device == kDeviceGpu; } +mlir::Value CopyAttributes(mlir::Operation *src, mlir::Operation *dest) { + // This is not expected to happen in practice. + if (dest->getNumResults() != 1) + llvm_unreachable("expected single result in `dest`"); + dest->setAttrs(src->getAttrs()); + return dest->getResult(0); +} } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h index bdebeed6351e1a..4d4ba9ff686e5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h @@ -70,6 +70,9 @@ bool IsConstantValueOf(Value value, T raw_value) { // devices or the device is not specified. bool IsOnGpuDevice(mlir::Operation *op); +// Copy the attributes of `src` op to the `dest` op , and return the first +// result of the `dest` op. +mlir::Value CopyAttributes(mlir::Operation *src, mlir::Operation *dest); } // namespace TF } // namespace mlir From 17d621bd352f8f51be3ea89d7882f046e69b451c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 15:19:24 -0700 Subject: [PATCH 145/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/e143099db75fe9b89b3959965840dd3c4237eab3. 
PiperOrigin-RevId: 551011948 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 3b2e8a79e554b7..f94f03e669937a 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "185b3fe2a676620ece266b6a3e4df6d5fb7264ae" - TFRT_SHA256 = "6c9f82696dac444d2d5c8a409c329170145bf6304de910ec0d1f608fec3f2d92" + TFRT_COMMIT = "e143099db75fe9b89b3959965840dd3c4237eab3" + TFRT_SHA256 = "5591f65576dabeadbba8ebe7d2d95d47576389de9b0f433cbc4e33fb9bd7eb08" tf_http_archive( name = "tf_runtime", From 4630988678072109859f2b47634dc95fca5b0984 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Tue, 25 Jul 2023 15:22:33 -0700 Subject: [PATCH 146/410] [XLA/GPU] Handle rematerialized clones in PGLE profile - Detect clones of instructions created by HLO rematerialization pass when reading PGLE profile, and if seen, use the average cost of the instruction across all of its clones and the original one as the cost to be used for PGLE. PiperOrigin-RevId: 551012789 --- .../xla/service/gpu/gpu_hlo_schedule.cc | 53 +++++++++++++-- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 66 +++++++++++++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index abe07d45d41508..bb46e1aa421da5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/hlo/utils/hlo_query.h" @@ -410,10 +412,11 @@ tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( tensorflow::profiler::ProfiledInstructionsProto& profile, const std::string& fingerprint) { tensorflow::profiler::ProfiledInstructionsProto result; + bool merge_remat_clones = false; for (const auto& cost : profile.costs()) { - std::string cost_name = cost.name(); - std::string new_cost_name = cost_name; - std::string cost_sep = "::"; + absl::string_view cost_name = cost.name(); + std::string new_cost_name = cost.name(); + absl::string_view cost_sep = "::"; if (absl::StrContains(cost_name, cost_sep)) { std::vector split_names = absl::StrSplit(cost_name, cost_sep); @@ -423,12 +426,54 @@ tensorflow::profiler::ProfiledInstructionsProto GetProfileForFingerprint( new_cost_name = split_names[1]; } + // Check if we see instructions that have ".rematX" suffix. These are clones + // of original instructions created by HLO rematerialization pass. We will + // average the costs of the remat clones and the original instruction and + // use that as the new cost of the original one. 
+ merge_remat_clones |= absl::StrContains(new_cost_name, ".remat"); auto* new_cost = result.add_costs(); new_cost->set_cost_us(cost.cost_us()); new_cost->set_name(new_cost_name); } - return result; + if (!merge_remat_clones) { + return result; + } + + auto strip_remat_suffix = [](absl::string_view name) -> absl::string_view { + absl::string_view suffix = ".remat"; + size_t index = name.rfind(suffix); + if (index == std::string::npos) { + return name; + } + auto after_suffix = name.substr(index + suffix.size()); + // Everything after ".remat" should be a digit or empty. If yes, strip the + // .rematN suffix. + int64_t numeric_suffix; + if (after_suffix.empty() || + absl::SimpleAtoi(after_suffix, &numeric_suffix)) { + return name.substr(0, index); + } + return name; + }; + + // Map from stripped name -> pair + absl::flat_hash_map> costs; + for (const auto& cost : result.costs()) { + std::pair& data = costs[strip_remat_suffix(cost.name())]; + data.first += cost.cost_us(); + data.second++; + } + + tensorflow::profiler::ProfiledInstructionsProto merged_result; + for (const auto& cost : costs) { + auto* new_cost = merged_result.add_costs(); + double average = cost.second.first / cost.second.second; + new_cost->set_cost_us(average); + new_cost->set_name(std::string(cost.first)); + } + + return merged_result; } std::optional ReadPGLEProfile( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index c77806b89fcdf6..68260838a41e98 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -415,6 +415,72 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) { } } +TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithRematData) { + const char* hlo_text = R"( + HloModule AsyncAR + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32, 32] parameter(1) + p2 = f32[32, 32] parameter(2) + p3 = f32[32] parameter(3) + + // Independent compute + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + // Independent collectives. + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT t = (f32[32], f32[32], f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Costs of "ar-start" and "ar-start.remat100" should be averaged out and + // used as cost for "ar-start". 
+ const std::string ar_long_latency_proto_text = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "dot1" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1.0 } + costs { name: "ar-start1" cost_us: 1.0 } + costs { name: "ar-start.remat100" cost_us: 2000.0 } + )pb"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, + GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true, + /*fdo_profile=*/ar_long_latency_proto_text))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + + HloComputation* entry = module->entry_computation(); + + // We expect all the math instructions between ar-start/ar-done + bool between_target_collective_pair = false; + for (const HloInstruction* inst : + order.SequentialOrder(*entry)->instructions()) { + if (inst->name() == "ar-start") { + between_target_collective_pair = true; + } else if (inst->name() == "ar-done") { + between_target_collective_pair = false; + } else if (inst->opcode() == HloOpcode::kDot || + inst->opcode() == HloOpcode::kAdd) { + EXPECT_TRUE(between_target_collective_pair); + } + } +} + // Checks that the Send and Recv sequence created by the CollectivePermute // decomposer is properly scheduled: // recv From c2dccf85cdc9f4db95b041627c5f6193dfb93dc6 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Tue, 25 Jul 2023 15:27:11 -0700 Subject: [PATCH 147/410] Fix a typo referencing the old directory structure PiperOrigin-RevId: 551014092 --- ci/official/utilities/rename_and_verify_wheels.sh | 2 +- ci/official/utilities/wheel_verification.bats | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 502ce8ac4488a9..8b499afd7c4438 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -25,7 +25,7 @@ for wheel in *.whl; do time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt # We don't need the original wheel if it was renamed - new_wheel=$(grep --extended-regexp --only-matching '/tf/pkg/\S+.whl' check.txt) + new_wheel=$(grep --extended-regexp --only-matching '\S+.whl' check.txt) if [[ "$new_wheel" != "$wheel" ]]; then rm "$wheel" wheel="$new_wheel" diff --git a/ci/official/utilities/wheel_verification.bats b/ci/official/utilities/wheel_verification.bats index 626954d4570387..c2ea6999cc8972 100644 --- a/ci/official/utilities/wheel_verification.bats +++ b/ci/official/utilities/wheel_verification.bats @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Suite of verification tests for the SINGLE TensorFlow wheel in /tf/pkg -# or whatever path is set as $TF_WHEEL. +# Suite of verification tests for the SINGLE TensorFlow wheel in the "build" +# directory, or whatever path is set as $TF_WHEEL. setup_file() { cd build From e8f4a81374a8c72be4b34ffc9178433f3e189901 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jul 2023 15:31:44 -0700 Subject: [PATCH 148/410] Temporarily remove `createFoldBroadcastPass` PiperOrigin-RevId: 551015318 --- .../compiler/mlir/lite/stablehlo/transforms/transforms.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index a1d9fabb8b3b0a..9965b56dda02fc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -70,7 +70,6 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, void AddMhloOptimizationPasses(OpPassManager& pm) { pm.addNestedPass(createUnfuseBatchNormPass()); pm.addNestedPass(createFuseConvolutionPass()); - pm.addNestedPass(createFoldBroadcastPass()); pm.addNestedPass(createOptimizePass()); pm.addPass(mlir::createCanonicalizerPass()); } @@ -83,6 +82,9 @@ void AddStablehloOptimizationPasses(OpPassManager& pm) { // this happen. pm.addPass(mhlo::createStablehloLegalizeToHloPass()); AddMhloOptimizationPasses(pm); + // TODO(b/293149194) Add `createFoldBroadcastPass` back to + // `AddMhloOptimizationPasses` + pm.addNestedPass(createFoldBroadcastPass()); pm.addPass(mhlo::createHloLegalizeToStablehloPass()); } From a7fda59bed93022bc52e373a264e7b48619a856a Mon Sep 17 00:00:00 2001 From: Chuanhao Zhuge Date: Tue, 25 Jul 2023 15:45:27 -0700 Subject: [PATCH 149/410] [TF:PJRT] Disable strict_shape_checking for GPU. The shape mismatch we observe (f32[1]{0} vs. f32[]) is not a real one. Probably generated by bridge. PiperOrigin-RevId: 551019078 --- tensorflow/compiler/jit/xla_launch_util.cc | 12 ++++++++++-- tensorflow/compiler/jit/xla_launch_util.h | 1 + tensorflow/compiler/jit/xla_launch_util_test.cc | 4 +++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index f3b82684472bd7..e4c1d7c8309e2c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -770,10 +770,16 @@ Status PopulateCtxOutputsFromPjRtExecutableOutputs( } xla::ExecuteOptions GetPjRtExecuteOptions( + const DeviceType& device_type, absl::flat_hash_set non_donatable_input_indices) { xla::ExecuteOptions options; options.arguments_are_tupled = false; options.untuple_result = true; + // TODO(b/293186653): investigate we should turn on strict shape checking for + // GPU. + if (device_type == DEVICE_GPU) { + options.strict_shape_checking = false; + } // Note: TF does not use PJRT host callbacks as of today. Setting this option // to true to workaround an ExecuteOptions check: [1]. 
// @@ -822,9 +828,10 @@ Status RunPjRtExecutable( ->tensorflow_accelerator_device_info() ->use_pjrt_tensor_buffer; + const DeviceType& device_type = GetDeviceType(ctx); TF_ASSIGN_OR_RETURN(const int pjrt_device_id, tsl::GetDeviceIdFromDeviceParsedName( - ctx->device()->parsed_name(), GetDeviceType(ctx))); + ctx->device()->parsed_name(), device_type)); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device, pjrt_client->LookupAddressableDevice(pjrt_device_id)); @@ -839,7 +846,8 @@ Status RunPjRtExecutable( std::vector> execute_outputs, executable->ExecutePortable( executable_args, device, - GetPjRtExecuteOptions(std::move(non_donatable_input_indices)))); + GetPjRtExecuteOptions(device_type, + std::move(non_donatable_input_indices)))); // We need to ensure the PjRtBuffers owned by `owned_executable_args` live // until execution is complete. diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 6e68a023c9b6ea..214ded00e09562 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -105,6 +105,7 @@ Status PopulateCtxOutputsFromPjRtExecutableOutputs( // Returns the options used for executing a PjRtLoadedExecutable. xla::ExecuteOptions GetPjRtExecuteOptions( + const DeviceType& device_type, absl::flat_hash_set non_donatable_input_indices); // Returns the device ordinal from the parsed name of the device. diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 51f8245f5769b1..440903a90c6c79 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -514,9 +514,11 @@ TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsResourceUpdates) { } TEST(XlaLaunchUtilTest, GetPjRtExecuteOptions) { - xla::ExecuteOptions options = GetPjRtExecuteOptions({}); + xla::ExecuteOptions options = + GetPjRtExecuteOptions(DeviceType(DEVICE_GPU), {}); EXPECT_FALSE(options.arguments_are_tupled); EXPECT_TRUE(options.untuple_result); + EXPECT_FALSE(options.strict_shape_checking); EXPECT_TRUE(options.use_major_to_minor_data_layout_for_callbacks); } From 4b6288d11e32a7f1e4ecde8d2966b3600450e6f4 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Tue, 25 Jul 2023 15:47:59 -0700 Subject: [PATCH 150/410] Fill in more of of continuous and nightly envs and jobs PiperOrigin-RevId: 551019738 --- .../envs/continuous_linux_x86_cpu_py310 | 23 +++++++++++++++++++ .../envs/continuous_linux_x86_cpu_py311 | 23 +++++++++++++++++++ .../envs/continuous_linux_x86_cpu_py39 | 23 +++++++++++++++++++ .../envs/continuous_linux_x86_cuda_py310 | 23 +++++++++++++++++++ .../envs/continuous_linux_x86_cuda_py311 | 23 +++++++++++++++++++ .../envs/continuous_linux_x86_cuda_py39 | 23 +++++++++++++++++++ .../envs/nightly_libtensorflow_linux_x86_cpu | 23 +++++++++++++++++++ ...0 => nightly_libtensorflow_linux_x86_cuda} | 0 ..._cpu_py310 => nightly_linux_x86_cpu_py310} | 0 ..._cpu_py311 => nightly_linux_x86_cpu_py311} | 0 ...ly_cpu_py39 => nightly_linux_x86_cpu_py39} | 0 ci/official/envs/nightly_linux_x86_cuda_py310 | 23 +++++++++++++++++++ ...uda_py311 => nightly_linux_x86_cuda_py311} | 0 ..._cuda_py39 => nightly_linux_x86_cuda_py39} | 0 14 files changed, 184 insertions(+) create mode 100644 ci/official/envs/continuous_linux_x86_cpu_py310 create mode 100644 ci/official/envs/continuous_linux_x86_cpu_py311 create mode 100644 ci/official/envs/continuous_linux_x86_cpu_py39 create mode 100644 ci/official/envs/continuous_linux_x86_cuda_py310 create mode 100644 
ci/official/envs/continuous_linux_x86_cuda_py311 create mode 100644 ci/official/envs/continuous_linux_x86_cuda_py39 create mode 100644 ci/official/envs/nightly_libtensorflow_linux_x86_cpu rename ci/official/envs/{nightly_cuda_py310 => nightly_libtensorflow_linux_x86_cuda} (100%) rename ci/official/envs/{nightly_cpu_py310 => nightly_linux_x86_cpu_py310} (100%) rename ci/official/envs/{nightly_cpu_py311 => nightly_linux_x86_cpu_py311} (100%) rename ci/official/envs/{nightly_cpu_py39 => nightly_linux_x86_cpu_py39} (100%) create mode 100644 ci/official/envs/nightly_linux_x86_cuda_py310 rename ci/official/envs/{nightly_cuda_py311 => nightly_linux_x86_cuda_py311} (100%) rename ci/official/envs/{nightly_cuda_py39 => nightly_linux_x86_cuda_py39} (100%) diff --git a/ci/official/envs/continuous_linux_x86_cpu_py310 b/ci/official/envs/continuous_linux_x86_cpu_py310 new file mode 100644 index 00000000000000..46997fb367d3db --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cpu_py310 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/continuous_linux_x86_cpu_py311 b/ci/official/envs/continuous_linux_x86_cpu_py311 new file mode 100644 index 00000000000000..1e3b7df5ea6857 --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cpu_py311 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/continuous_linux_x86_cpu_py39 b/ci/official/envs/continuous_linux_x86_cpu_py39 new file mode 100644 index 00000000000000..fc7ca80562235d --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cpu_py39 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" 
+#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/continuous_linux_x86_cuda_py310 b/ci/official/envs/continuous_linux_x86_cuda_py310 new file mode 100644 index 00000000000000..5c6cf6f8397867 --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cuda_py310 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=() +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/continuous_linux_x86_cuda_py311 b/ci/official/envs/continuous_linux_x86_cuda_py311 new file mode 100644 index 00000000000000..039c1634c6c23c --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cuda_py311 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=() +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.11 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/continuous_linux_x86_cuda_py39 b/ci/official/envs/continuous_linux_x86_cuda_py39 new file mode 100644 index 00000000000000..1eae7b537a0598 --- /dev/null +++ b/ci/official/envs/continuous_linux_x86_cuda_py39 @@ -0,0 +1,23 @@ 
+#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=() +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.9 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI= +TFCI_UPLOAD_LIB_URI= +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu new file mode 100644 index 00000000000000..0c4b25904482dd --- /dev/null +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu --nightly_flag) +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=() +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-cpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS= +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cuda_py310 b/ci/official/envs/nightly_libtensorflow_linux_x86_cuda similarity index 100% rename from ci/official/envs/nightly_cuda_py310 rename to ci/official/envs/nightly_libtensorflow_linux_x86_cuda diff --git a/ci/official/envs/nightly_cpu_py310 b/ci/official/envs/nightly_linux_x86_cpu_py310 similarity index 100% rename from ci/official/envs/nightly_cpu_py310 rename to ci/official/envs/nightly_linux_x86_cpu_py310 diff --git a/ci/official/envs/nightly_cpu_py311 b/ci/official/envs/nightly_linux_x86_cpu_py311 similarity index 100% rename from ci/official/envs/nightly_cpu_py311 rename to ci/official/envs/nightly_linux_x86_cpu_py311 diff --git a/ci/official/envs/nightly_cpu_py39 b/ci/official/envs/nightly_linux_x86_cpu_py39 similarity index 100% rename from ci/official/envs/nightly_cpu_py39 rename to ci/official/envs/nightly_linux_x86_cpu_py39 diff --git a/ci/official/envs/nightly_linux_x86_cuda_py310 b/ci/official/envs/nightly_linux_x86_cuda_py310 new file mode 100644 index 00000000000000..30bbaf108bee5e --- /dev/null +++ b/ci/official/envs/nightly_linux_x86_cuda_py310 @@ -0,0 +1,23 @@ +#TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" 
+#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" +TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) +TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BUILD_PIP_PACKAGE_ARGS=(--nightly_flag) +TFCI_COPYBARA_ENABLE=0 +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_GPU_ARGS=(--gpus all) +TFCI_DOCKER_IMAGE=tensorflow/build:latest-python3.10 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow +TFCI_INDEX_HTML_ENABLE=1 +TFCI_LIB_SUFFIX="-gpu-linux-x86_64" +TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_UPLOAD_LIB_ENABLE= +TFCI_UPLOAD_LIB_LATEST_ENABLE= +TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" +TFCI_UPLOAD_LIB_URI="gs://libtensorflow-nightly/$(date -I)" +TFCI_UPLOAD_WHL_GCS_ENABLE= +TFCI_UPLOAD_WHL_GCS_URI= +TFCI_UPLOAD_WHL_PYPI_ARGS=( --config-file "$KOKORO_KEYSTORE_DIR/73361_tensorflow_pypirc_using_global_api_token" --repository testpypi ) +TFCI_UPLOAD_WHL_PYPI_ENABLE= diff --git a/ci/official/envs/nightly_cuda_py311 b/ci/official/envs/nightly_linux_x86_cuda_py311 similarity index 100% rename from ci/official/envs/nightly_cuda_py311 rename to ci/official/envs/nightly_linux_x86_cuda_py311 diff --git a/ci/official/envs/nightly_cuda_py39 b/ci/official/envs/nightly_linux_x86_cuda_py39 similarity index 100% rename from ci/official/envs/nightly_cuda_py39 rename to ci/official/envs/nightly_linux_x86_cuda_py39 From 587870e747e1c7f694c0108dd5b78da5720b4f16 Mon Sep 17 00:00:00 2001 From: Clive Verghese Date: Tue, 25 Jul 2023 16:06:32 -0700 Subject: [PATCH 151/410] Log when Find has more than 1 entity. PiperOrigin-RevId: 551024847 --- tensorflow/tsl/profiler/utils/xplane_utils.cc | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tensorflow/tsl/profiler/utils/xplane_utils.cc b/tensorflow/tsl/profiler/utils/xplane_utils.cc index 73b9ab099f17b5..caaf68bbd24c28 100644 --- a/tensorflow/tsl/profiler/utils/xplane_utils.cc +++ b/tensorflow/tsl/profiler/utils/xplane_utils.cc @@ -45,16 +45,6 @@ namespace tsl { namespace profiler { namespace { -// Returns the index of the first element in array for which pred is true. -// Returns -1 if no such element is found. -template -int Find(const protobuf::RepeatedPtrField& array, const Pred& pred) { - for (int i = 0; i < array.size(); ++i) { - if (pred(&array.Get(i))) return i; - } - return -1; -} - // Returns the indices of all elements in array for which pred is true. template std::vector FindAll(const protobuf::RepeatedPtrField& array, @@ -66,6 +56,18 @@ std::vector FindAll(const protobuf::RepeatedPtrField& array, return indices; } +// Returns the index of the first element in array for which pred is true. +// Returns -1 if no such element is found. +template +int Find(const protobuf::RepeatedPtrField& array, const Pred& pred) { + std::vector indices = FindAll(array, pred); + if (indices.size() > 1) { + LOG(WARNING) << "Found multiple " << T::descriptor()->name() + << " when only one was expected."; + } + return indices.empty() ? -1 : indices.front(); +} + template void RemoveAt(protobuf::RepeatedPtrField* array, const std::vector& indices) { From 16cb233a82b7e1accfa36bc1ee928c0be1ea98d7 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 25 Jul 2023 16:07:32 -0700 Subject: [PATCH 152/410] [xla][gpu] Use two runtime streams for asynchronous operations. Add AsyncStreamKind enum for GPU runtime streams. Previously, we use one runtime stream to execute asynchronous operations. 
We now use two runtime streams, one for asynchronous collective operations and one for P2P Send and Recv. This allows asynchronous collective operations to be overlapped with Send and Recv operations for performance. PiperOrigin-RevId: 551025117 --- .../xla/service/gpu/gpu_executable.cc | 24 +++++++++---- .../xla/service/gpu/nccl_collective_thunk.cc | 13 ++++--- .../xla/service/gpu/nccl_collective_thunk.h | 10 +++++- .../xla/service/gpu/nccl_recv_thunk.h | 3 ++ .../xla/service/gpu/nccl_send_thunk.h | 3 ++ .../compiler/xla/service/gpu/runtime/BUILD | 1 + .../xla/service/gpu/runtime/collectives.cc | 35 +++++++++++-------- .../xla/service/gpu/runtime/collectives.h | 13 ++++--- .../xla/service/gpu/runtime/executable.cc | 17 ++++++--- tensorflow/compiler/xla/service/gpu/thunk.cc | 4 +-- tensorflow/compiler/xla/service/gpu/thunk.h | 11 ++++-- 11 files changed, 93 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index cfafa628346571..7f3ea145a71313 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -53,6 +53,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/errors.h" @@ -223,8 +224,14 @@ Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, stream_priority = stream_executor::StreamPriority::Highest; } - StatusOr async_comms_stream = - run_options->BorrowStream(executor->device_ordinal(), stream_priority); + // Create the needed streams to support NcclCollectiveThunk. + absl::InlinedVector async_comms_streams; + for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { + StatusOr async_comms_stream = + run_options->BorrowStream(executor->device_ordinal(), stream_priority); + async_comms_streams.push_back( + async_comms_stream.ok() ? async_comms_stream->get() : nullptr); + } uint64_t start_nanos = tsl::Env::Default()->NowNanos(); @@ -247,12 +254,15 @@ Status ExecuteThunks(const std::string& module_name, ModuleIdentifier module_id, // module, we won't get any data, but that's probably an OK trade-off. ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); VLOG(2) << "Executing the thunk for " << thunk->profile_annotation(); - TF_RET_CHECK(async_comms_stream.ok() || !NeedsAsyncCommsStream(*thunk)) - << "`run_options` must have a stream borrower for async thunks."; + if (NeedsAsyncCommsStream(*thunk)) { + for (se::Stream* async_stream : async_comms_streams) { + TF_RET_CHECK(async_stream != nullptr) + << "`run_options` must have a stream borrower for async thunks."; + } + } - Thunk::ExecuteParams thunk_params{ - *run_options, buffer_allocations, main_stream, - async_comms_stream.ok() ? 
async_comms_stream->get() : nullptr}; + Thunk::ExecuteParams thunk_params{*run_options, buffer_allocations, + main_stream, async_comms_streams}; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params)); } return MaybeSyncAndProfile(run_options, start_nanos, diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc index f888837bd2d0c7..ac34422562b281 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" +#include #include #include #include @@ -224,7 +225,7 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { #if XLA_ENABLE_XCCL VLOG(1) << absl::StreamFormat("Starting %s %s.", IsAsync() ? "async" : "sync", Thunk::KindToString(kind())); - const int64_t stream_id = IsAsync() ? 1 : 0; + const int64_t stream_id = GetStreamId(); TF_ASSIGN_OR_RETURN( NcclComm::Lock comm, LockNcclComm(params.nccl_params, config().replica_groups, @@ -240,7 +241,7 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { ncclComm_t comm) { return RunNcclCollective(params, stream, comm); }, - params, *comm); + params, *comm, GetAsyncStreamKind()); }(); TF_RETURN_IF_ERROR(status); @@ -249,7 +250,9 @@ Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { // continue enqueuing operations. Otherwise, the allocations can cause // deadlock in the CUDA driver (b/215649390). if (first_call_to_execute_) { - se::Stream* stream = IsAsync() ? params.async_comms_stream : params.stream; + se::Stream* stream = IsAsync() + ? params.async_comms_streams[GetAsyncStreamKind()] + : params.stream; TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); first_call_to_execute_ = false; } @@ -274,8 +277,8 @@ std::string NcclCollectiveThunk::GetDeviceString( Status NcclCollectiveThunk::AsyncExecutor::Execute( absl::FunctionRef fn, - const ExecuteParams& params, ncclComm_t comm) { - se::Stream& async_comms_stream = *params.async_comms_stream; + const ExecuteParams& params, ncclComm_t comm, AsyncStreamKind stream_kind) { + se::Stream& async_comms_stream = *params.async_comms_streams[stream_kind]; // Wait until compute inputs are ready. async_comms_stream.ThenWaitFor(params.stream); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h index 45cf57e57f6b66..de29028e2ad5ec 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h @@ -113,7 +113,8 @@ class NcclCollectiveThunk : public Thunk { Status Execute( absl::FunctionRef fn, - const ExecuteParams& params, ncclComm_t comm); + const ExecuteParams& params, ncclComm_t comm, + AsyncStreamKind stream_kind); // Blocks the compute stream until async communication is complete. Status Await(const ExecuteParams& params); @@ -142,9 +143,16 @@ class NcclCollectiveThunk : public Thunk { virtual Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) = 0; virtual const NcclCollectiveConfig& config() const = 0; + virtual AsyncStreamKind GetAsyncStreamKind() const { + return kAsyncStreamCollective; + } private: bool IsAsync() const { return async_ != nullptr; } + int64_t GetStreamId() const { + return IsAsync() ? 
1 + GetAsyncStreamKind() : 0; + } + #if XLA_ENABLE_XCCL bool first_call_to_execute_ = true; #endif // XLA_ENABLE_XCCL diff --git a/tensorflow/compiler/xla/service/gpu/nccl_recv_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_recv_thunk.h index 32ca797bd105f9..65fb7c0c34af62 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_recv_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_recv_thunk.h @@ -47,6 +47,9 @@ class NcclRecvThunk : public NcclCollectiveThunk { const NcclCollectiveConfig& config() const override { return config_.config; } Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) override; + AsyncStreamKind GetAsyncStreamKind() const override { + return kAsyncStreamP2P; + } private: const NcclP2PConfig config_; diff --git a/tensorflow/compiler/xla/service/gpu/nccl_send_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_send_thunk.h index ea895bdfcfd49d..f56a1d97d7ac7c 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_send_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_send_thunk.h @@ -46,6 +46,9 @@ class NcclSendThunk : public NcclCollectiveThunk { const NcclCollectiveConfig& config() const override { return config_.config; } Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, ncclComm_t comm) override; + AsyncStreamKind GetAsyncStreamKind() const override { + return kAsyncStreamP2P; + } private: const NcclP2PConfig config_; diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 9076401a67147d..7be2899dcd15c8 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -53,6 +53,7 @@ cc_library( "//tensorflow/compiler/xla/service:global_device_id", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xla/service/gpu:nccl_collective_thunks", + "//tensorflow/compiler/xla/service/gpu:thunk", "//tensorflow/compiler/xla/stream_executor:event", "//tensorflow/compiler/xla/stream_executor:executor_cache", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc index 47780a2613e578..604965eaf28176 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/collectives.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_recv_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_send_thunk.h" #include "tensorflow/compiler/xla/service/gpu/runtime/support.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" @@ -61,10 +62,11 @@ absl::Status RunSyncOrAsync( const ServiceExecutableRunOptions* run_options, CollectivesSupport* collectives, AsyncCollectivesSupport* async_collectives, int32_t uid, bool is_async, - absl::FunctionRef to_run) { + absl::FunctionRef to_run, + AsyncStreamKind stream_kind = kAsyncStreamCollective) { se::Stream* main_stream = run_options->stream(); - se::Stream* async_stream = async_collectives->async_comm_stream(); - + se::Stream* async_stream = + is_async ? async_collectives->async_comm_stream(stream_kind) : nullptr; if (is_async) { // Wait until compute inputs are ready. 
async_stream->ThenWaitFor(main_stream); @@ -75,7 +77,7 @@ absl::Status RunSyncOrAsync( TF_RETURN_IF_ERROR(to_run(stream)); if (is_async) { - TF_RETURN_IF_ERROR(async_collectives->RecordEvent(uid)); + TF_RETURN_IF_ERROR(async_collectives->RecordEvent(uid, stream_kind)); } int32_t device_ordinal = main_stream->parent()->device_ordinal(); return collectives->MaybeBlockAfterFirstRun(uid, device_ordinal, main_stream); @@ -295,7 +297,8 @@ static absl::Status P2PSendImpl(const ServiceExecutableRunOptions* run_options, group_mode, op_id, replica_group_offsets, replica_group_values, source_peers, target_peers, RunSend, GetSingleArgAsDeviceBufferPair, is_async); - }); + }, + kAsyncStreamP2P); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL disabled"); #endif // XLA_ENABLE_XCCL @@ -343,7 +346,8 @@ static absl::Status P2PRecvImpl(const ServiceExecutableRunOptions* run_options, group_mode, op_id, replica_group_offsets, replica_group_values, source_peers, target_peers, RunRecv, GetSingleArgAsDeviceBufferPair, is_async); - }); + }, + kAsyncStreamP2P); #else // XLA_ENABLE_XCCL return absl::InternalError("NCCL disabled"); #endif // XLA_ENABLE_XCCL @@ -716,21 +720,23 @@ absl::Status CollectivesSupport::MaybeBlockAfterFirstRun(int32_t uid, return block ? stream->BlockHostUntilDone() : absl::OkStatus(); } -AsyncCollectivesSupport::AsyncCollectivesSupport(se::Stream* async_comm_stream) - : async_comm_stream_(async_comm_stream) {} +AsyncCollectivesSupport::AsyncCollectivesSupport( + absl::Span async_streams) + : async_comm_streams_(async_streams.begin(), async_streams.end()) {} -absl::Status AsyncCollectivesSupport::RecordEvent(int32_t uid) { +absl::Status AsyncCollectivesSupport::RecordEvent( + int32_t uid, gpu::AsyncStreamKind async_stream_kind) { // Create an event on the async stream for the completion of the collective. - se::Event done_event(async_comm_stream_->parent()); + se::Event done_event(async_comm_stream(async_stream_kind)->parent()); if (!done_event.Init()) return absl::InternalError("Failed to create event"); - async_comm_stream_->ThenRecordEvent(&done_event); + async_comm_stream(async_stream_kind)->ThenRecordEvent(&done_event); absl::MutexLock lock(&mutex_); auto [_, was_inserted] = done_events_.insert({uid, std::move(done_event)}); if (!was_inserted) { return absl::InternalError(absl::StrFormat( "Async done event has not been consumed (uid=%d, device_ordinal=%d)", - uid, async_comm_stream_->parent()->device_ordinal())); + uid, async_comm_stream(async_stream_kind)->parent()->device_ordinal())); } return absl::OkStatus(); } @@ -739,9 +745,8 @@ absl::StatusOr AsyncCollectivesSupport::PopEvent(int32_t uid) { absl::MutexLock lock(&mutex_); auto done_event = done_events_.extract(uid); if (!done_event) { - return absl::InternalError(absl::StrFormat( - "Async done event was not found (uid=%d, device_ordinal=%d)", uid, - async_comm_stream_->parent()->device_ordinal())); + return absl::InternalError( + absl::StrFormat("Async done event was not found (uid=%d)", uid)); } return std::move(done_event.mapped()); } diff --git a/tensorflow/compiler/xla/service/gpu/runtime/collectives.h b/tensorflow/compiler/xla/service/gpu/runtime/collectives.h index 0911df63e63ce5..f00f2a2d7282ef 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/collectives.h +++ b/tensorflow/compiler/xla/service/gpu/runtime/collectives.h @@ -16,9 +16,12 @@ limitations under the License. 
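The stream-selection scheme this change introduces can be reduced to a short, self-contained sketch. The AsyncStreamKind values and the stream-id rule below are taken from this patch (the enum added to thunk.h and NcclCollectiveThunk::GetStreamId() above); the Stream type and the per-run array of borrowed streams are simplified stand-ins, not the real StreamExecutor API.

#include <array>
#include <cstdint>
#include <iostream>

// Simplified stand-in for se::Stream; the real code borrows streams through
// ServiceExecutableRunOptions::BorrowStream.
struct Stream { const char* name; };

// Mirrors the AsyncStreamKind enum this patch adds to thunk.h.
enum AsyncStreamKind {
  kAsyncStreamCollective = 0,  // Stream for asynchronous collective ops.
  kAsyncStreamP2P = 1,         // Stream for P2P Send and Recv ops.
};
constexpr int64_t kAsyncStreamTotal = kAsyncStreamP2P + 1;

// Each executable run now holds one async stream per kind, so collective ops
// no longer serialize against P2P Send/Recv on a single shared stream.
const std::array<Stream, kAsyncStreamTotal> async_comms_streams = {
    Stream{"collective-stream"}, Stream{"p2p-stream"}};

// An async thunk picks the stream matching its kind; a sync thunk keeps using
// the main compute stream.
const Stream& PickStream(bool is_async, AsyncStreamKind kind,
                         const Stream& main_stream) {
  return is_async ? async_comms_streams[kind] : main_stream;
}

// Same rule as NcclCollectiveThunk::GetStreamId() above:
// 0 for sync execution, 1 + kind for async execution.
int64_t GetStreamId(bool is_async, AsyncStreamKind kind) {
  return is_async ? 1 + kind : 0;
}

int main() {
  const Stream main_stream{"main-stream"};
  std::cout << PickStream(/*is_async=*/true, kAsyncStreamP2P, main_stream).name
            << " stream_id=" << GetStreamId(true, kAsyncStreamP2P) << "\n";
  // Prints: p2p-stream stream_id=2
}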
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_RUNTIME_COLLECTIVES_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/runtime/custom_call_registry.h" +#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" #include "tensorflow/compiler/xla/stream_executor/event.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" @@ -54,16 +57,18 @@ class CollectivesSupport { // Support for running async collective operations communicating via events. class AsyncCollectivesSupport { public: - explicit AsyncCollectivesSupport(se::Stream* async_comm_stream); + explicit AsyncCollectivesSupport(absl::Span async_streams); - absl::Status RecordEvent(int32_t uid); + absl::Status RecordEvent(int32_t uid, AsyncStreamKind async_stream_kind); absl::StatusOr PopEvent(int32_t uid); - se::Stream* async_comm_stream() const { return async_comm_stream_; } + se::Stream* async_comm_stream(AsyncStreamKind async_stream_kind) const { + return async_comm_streams_[async_stream_kind]; + } private: absl::Mutex mutex_; - se::Stream* async_comm_stream_; + absl::InlinedVector async_comm_streams_; // Store done events for the Done ops to wait upon. absl::flat_hash_map done_events_ ABSL_GUARDED_BY(mutex_); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc index d70f834c0ed94c..63c7a1e3aa99cf 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/executable.cc @@ -364,14 +364,19 @@ Status GpuRuntimeExecutable::Execute( stream_priority = se::StreamPriority::Highest; } - StatusOr async_comms_stream = - run_options->BorrowStream(executor->device_ordinal(), stream_priority); + // Create the needed streams to support NcclCollectiveThunk. + absl::InlinedVector async_comm_streams; + for (int64_t i = 0; i < kAsyncStreamTotal; ++i) { + StatusOr async_comm_stream = + run_options->BorrowStream(executor->device_ordinal(), stream_priority); + async_comm_streams.push_back( + async_comm_stream.ok() ? async_comm_stream->get() : nullptr); + } // Async Collectives support and Send/Recv events instantiated for each Gpu // executable run, so that concurrent executions can run independently using a // separate set of events for communication. - AsyncCollectivesSupport async_collectives( - async_comms_stream.ok() ? async_comms_stream->get() : nullptr); + AsyncCollectivesSupport async_collectives(async_comm_streams); SendRecvEvents send_recv_events; // Always pass in the temp buffer, even if it is null, to accommodate the @@ -428,7 +433,9 @@ Status GpuRuntimeExecutable::Execute( &concurrent_region_status, // Null pointer will be interpreted as an absence of async collectives // support and custom calls will safely return an error. - async_collectives.async_comm_stream() ? &async_collectives : nullptr); + async_collectives.async_comm_stream(kAsyncStreamCollective) + ? &async_collectives + : nullptr); // Initialize state required for running functions from registered modules. 
auto state_ref = modules_state_.InitializeUserData(user_data); diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc index d6d74efa18a9b1..db81e840ad32a0 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -29,10 +29,10 @@ namespace gpu { Thunk::ExecuteParams::ExecuteParams( const ServiceExecutableRunOptions& run_options, const BufferAllocations& buffer_allocations, se::Stream* stream, - se::Stream* async_comms_stream) + absl::Span async_streams) : buffer_allocations(&buffer_allocations), stream(stream), - async_comms_stream(async_comms_stream), + async_comms_streams(async_streams.begin(), async_streams.end()), nccl_params(run_options, stream->parent()) {} /*static*/ absl::string_view Thunk::KindToString(Thunk::Kind kind) { diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 59fd57cdd07a8c..37994212b3c75a 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -35,6 +35,12 @@ namespace gpu { class GpuExecutable; +enum AsyncStreamKind { + kAsyncStreamCollective = 0, // Stream for asynchronous collective ops. + kAsyncStreamP2P = 1, // Stream for P2P Send and Recv ops. +}; +constexpr static int64_t kAsyncStreamTotal = kAsyncStreamP2P + 1; + // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction. // @@ -131,11 +137,12 @@ class Thunk { struct ExecuteParams { ExecuteParams(const ServiceExecutableRunOptions& run_options, const BufferAllocations& buffer_allocations, - se::Stream* stream, se::Stream* async_comms_stream); + se::Stream* stream, + absl::Span async_streams); const BufferAllocations* buffer_allocations; // never null se::Stream* stream; - se::Stream* async_comms_stream; + absl::InlinedVector async_comms_streams; NcclExecuteParams nccl_params; }; From b61c3e80dd7c88624ffe3ab64f61baeb350bf53b Mon Sep 17 00:00:00 2001 From: Yi Situ Date: Tue, 25 Jul 2023 16:15:39 -0700 Subject: [PATCH 153/410] Refactor usages of deprecated `Length()` to `size()` for flatbuffer::Vector<>. Tweaked model analyzer messages. PiperOrigin-RevId: 551027251 --- tensorflow/lite/python/analyzer_test.py | 3 +- .../python/analyzer_wrapper/model_analyzer.cc | 50 +++++++++---------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/python/analyzer_test.py b/tensorflow/lite/python/analyzer_test.py index 85ab74cfa7b578..22a6a939ba900c 100644 --- a/tensorflow/lite/python/analyzer_test.py +++ b/tensorflow/lite/python/analyzer_test.py @@ -144,8 +144,7 @@ def func(x): model_content=fb_model, gpu_compatibility=True) txt = mock_stdout.getvalue() self.assertIn( - 'Your model looks compatible with GPU delegate with TFLite runtime', - txt) + 'Your model looks compatible with GPU delegate on TFLite runtime', txt) def testTxtSignatureDefs(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc index 126292dab88ec7..274a24883838b7 100644 --- a/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc +++ b/tensorflow/lite/python/analyzer_wrapper/model_analyzer.cc @@ -148,20 +148,20 @@ void dump_tensor_detail(std::stringstream& out_stream, // supports dynamic shapes. 
if (tensor->shape_signature()) { out_stream << "shape_signature:["; - for (int i = 0; i < tensor->shape_signature()->Length(); ++i) { + for (int i = 0; i < tensor->shape_signature()->size(); ++i) { const int j = tensor->shape_signature()->Get(i); out_stream << j; - if (i != tensor->shape_signature()->Length() - 1) { + if (i != tensor->shape_signature()->size() - 1) { out_stream << ", "; } } out_stream << "]"; } else if (tensor->shape()) { out_stream << "shape:["; - for (int i = 0; i < tensor->shape()->Length(); ++i) { + for (int i = 0; i < tensor->shape()->size(); ++i) { const int j = tensor->shape()->Get(i); out_stream << j; - if (i != tensor->shape()->Length() - 1) { + if (i != tensor->shape()->size() - 1) { out_stream << ", "; } } @@ -173,7 +173,7 @@ void dump_tensor_detail(std::stringstream& out_stream, // Dump buffer size of constant tensors. auto buffer_idx = tensor->buffer(); - if (buffer_idx != 0 && buffer_idx < model->buffers()->Length()) { + if (buffer_idx != 0 && buffer_idx < model->buffers()->size()) { auto* buffer = model->buffers()->Get(buffer_idx); if (buffer->data() && buffer->data()->size() != 0) { out_stream << " RO " << buffer->data()->size() << " bytes"; @@ -194,14 +194,14 @@ void dump_tensor_list(std::stringstream& out_stream, if (tensors == nullptr) { return; } - for (int i = 0; i < tensors->Length(); ++i) { + for (int i = 0; i < tensors->size(); ++i) { const int tensor_idx = tensors->Get(i); if (verbose) { out_stream << "tensor #" << tensor_idx; } else { out_stream << tensor_str(tensor_idx, subgraph_idx, model); } - if (i != tensors->Length() - 1) { + if (i != tensors->size() - 1) { if (verbose) { out_stream << " and "; } else { @@ -257,10 +257,10 @@ void dump_model_summary(std::stringstream& out_stream, const ::tflite::Model* model) { auto* subgraphs = model->subgraphs(); out_stream - << "Your TFLite model has '" << subgraphs->Length() + << "Your TFLite model has '" << subgraphs->size() << "' subgraph(s). In the subgraph description below,\nT# represents the " "Tensor numbers. 
"; - if (subgraphs->Length() > 0 && subgraphs->Get(0)->operators()->Length() > 0) { + if (subgraphs->size() > 0 && subgraphs->Get(0)->operators()->size() > 0) { const Operator* first_op = subgraphs->Get(0)->operators()->Get(0); const OperatorCode* first_op_code = model->operator_codes()->Get(first_op->opcode_index()); @@ -279,20 +279,20 @@ void dump_model_summary(std::stringstream& out_stream, void dump_model_signature_defs(std::stringstream& out_stream, const ::tflite::Model* model) { auto* signatures = model->signature_defs(); - if (signatures == nullptr || signatures->Length() == 0) { + if (signatures == nullptr || signatures->size() == 0) { return; } out_stream << kSectionSplitter; - out_stream << "Your TFLite model has '" << signatures->Length() + out_stream << "Your TFLite model has '" << signatures->size() << "' signature_def(s).\n\n"; - for (int i = 0; i < signatures->Length(); ++i) { + for (int i = 0; i < signatures->size(); ++i) { auto* signature_def = signatures->Get(i); out_stream << "Signature#" << i << " key: '" << signature_def->signature_key()->str() << "'\n"; out_stream << "- Subgraph: " << subgraph_str(signature_def->subgraph_index()) << "\n"; out_stream << "- Inputs: \n"; - for (int j = 0; j < signature_def->inputs()->Length(); ++j) { + for (int j = 0; j < signature_def->inputs()->size(); ++j) { auto* input = signature_def->inputs()->Get(j); out_stream << " '" << input->name()->str() << "' : " << tensor_str(input->tensor_index(), @@ -300,7 +300,7 @@ void dump_model_signature_defs(std::stringstream& out_stream, << "\n"; } out_stream << "- Outputs: \n"; - for (int j = 0; j < signature_def->outputs()->Length(); ++j) { + for (int j = 0; j < signature_def->outputs()->size(); ++j) { auto* output = signature_def->outputs()->Get(j); out_stream << " '" << output->name()->str() << "' : " << tensor_str(output->tensor_index(), @@ -351,8 +351,8 @@ void dump_model_stats(std::stringstream& out_stream, "Total data buffer size", total_buffer_size, (static_cast(total_buffer_size) / model_size * 100)); out_stream << temp; - if (model->subgraphs()->Length() > 1) { - for (int i = 0; i < model->subgraphs()->Length(); ++i) { + if (model->subgraphs()->size() > 1) { + for (int i = 0; i < model->subgraphs()->size(); ++i) { float subgraph_buffer_ratio = static_cast(stats->buffer_usage[i]) / model_size * 100; snprintf(temp, sizeof(temp), @@ -434,12 +434,12 @@ std::string model_analyzer(const std::string& model_file_or_buffer, const ::tflite::Model* model = fb_model->GetModel(); auto* subgraphs = model->subgraphs(); ModelStats stats; - stats.buffer_usage.resize(subgraphs->Length()); + stats.buffer_usage.resize(subgraphs->size()); dump_model_summary(out_stream, model); bool model_is_gpu_compatible = true; - for (int i = 0; i < subgraphs->Length(); ++i) { + for (int i = 0; i < subgraphs->size(); ++i) { std::vector gpu_incompatible_nodes; const SubGraph* subgraph = subgraphs->Get(i); out_stream << subgraph_str(i); @@ -452,7 +452,7 @@ std::string model_analyzer(const std::string& model_file_or_buffer, dump_tensor_list(out_stream, subgraph->outputs(), i); out_stream << "]\n"; if (subgraph->operators()) { - for (int j = 0; j < subgraph->operators()->Length(); ++j) { + for (int j = 0; j < subgraph->operators()->size(); ++j) { const Operator* op = subgraph->operators()->Get(j); const OperatorCode* op_code = model->operator_codes()->Get(op->opcode_index()); @@ -474,15 +474,14 @@ std::string model_analyzer(const std::string& model_file_or_buffer, out_stream << "\nGPU COMPATIBILITY WARNING: Subgraph#" << i << 
" has GPU delegate compatibility issues at nodes " << absl::StrJoin(gpu_incompatible_nodes, ", ") - << " with TFLite runtime version " << TF_VERSION_STRING - << "\n"; + << " on TFLite runtime version " << TF_VERSION_STRING << "\n"; } // Dump Subgraph Tensors. out_stream << "\nTensors of " << subgraph_str(i) << "\n"; auto tensors = subgraph->tensors(); if (tensors) { - for (int j = 0; j < tensors->Length(); ++j) { + for (int j = 0; j < tensors->size(); ++j) { auto tensor = tensors->Get(j); out_stream << " "; // indents for tensors dump_tensor_detail(out_stream, tensor, j, i, model, &stats); @@ -493,9 +492,10 @@ std::string model_analyzer(const std::string& model_file_or_buffer, if (check_gpu_compatibility && model_is_gpu_compatible) { out_stream << "\nYour model looks compatible with GPU delegate" - << " with TFLite runtime version " << TF_VERSION_STRING - << ".\nBut it doesn't guarantee that your model works well with GPU " - "delegate.\nThere could be some runtime incompatibililty happen.\n"; + << " on TFLite runtime version " << TF_VERSION_STRING + << ".\nThis does not guarantee that your model will work well with GPU" + " delegate because there could still be runtime " + "incompatibililties.\n"; } dump_model_signature_defs(out_stream, model); From 63ffa80a98f0345150175b1b9ee6243f51925c32 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 16:22:51 -0700 Subject: [PATCH 154/410] Added functions to count float and integer values within a literal in LiteralBase. PiperOrigin-RevId: 551029097 --- tensorflow/compiler/xla/literal.cc | 24 +++++++++++ tensorflow/compiler/xla/literal.h | 54 ++++++++++++++++++++++++ tensorflow/compiler/xla/literal_test.cc | 55 +++++++++++++++++++++++++ 3 files changed, 133 insertions(+) diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index e678582d6b2b10..c9e5893d27bb0d 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1936,6 +1936,30 @@ bool Literal::Piece::IsAll(const Literal& scalar) const { subshape().element_type()); } +int64_t Literal::Piece::CountAll(const Literal& scalar) const { + CHECK(ShapeUtil::IsScalar(scalar.shape())) << scalar.shape().ToString(); + if (!subshape().IsArray()) { + return 0; + } + + CHECK(LayoutUtil::IsDenseArray(subshape())) + << __func__ << " is only supported for dense arrays: " << subshape(); + CHECK_EQ(subshape().element_type(), scalar.shape().element_type()); + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> int64_t { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = NativeTypeOf; + return absl::c_count_if( + this->data(), [&](NativeT elem) -> bool { + return EqualIncludingNan(elem, + scalar.GetFirstElement()); + }); + } + return 0; + }, + subshape().element_type()); +} + bool LiteralBase::IsAll(const Literal& scalar) const { return root_piece().IsAll(scalar); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 78940744a2bb4a..05485887a7af96 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -276,6 +276,18 @@ class LiteralBase { // Returns false if this literal is not an array. bool IsAllFirst() const; + // Returns the number of elements that have value equal to the given value. + // Returns 0 if value does not fit in this literal's type or if the literal + // is not an array. 
+ template + int64_t CountEqual(T value) const; + + // Returns the number of elements that have value equal to the given complex + // value. Returns 0 if value does not fit in this literal's type or if the + // literal is not an array. + template + int64_t CountEqual(std::complex value) const; + // Literal consists entirely of an iota. bool IsR1Iota() const; @@ -642,6 +654,10 @@ class LiteralBase { // - `scalar`'s type matches that of `this`. bool IsAll(const Literal& scalar) const; + // Returns the number of elements with equal value to the given literal. + // Returns 0 if this Piece is not an array. + int64_t CountAll(const Literal& scalar) const; + // Returns true if this piece and 'other' contain the same data. This piece // and 'other' must be array-shaped and compatible. If a literal has dynamic // shape, comparison is done only for the valid elements. @@ -1313,6 +1329,44 @@ NativeT LiteralBase::GetFirstElement() const { return data().at(0); } +template +int64_t LiteralBase::CountEqual(T value) const { + if (!shape().IsArray()) { + return 0; + } + PrimitiveType ty = shape().element_type(); + Literal scalar(ShapeUtil::MakeScalarShape(ty)); + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> int64_t { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + scalar.Set({}, static_cast(value)); + return root_piece().CountAll(scalar); + } + return 0; + }, + ty); +} + +template +int64_t LiteralBase::CountEqual(std::complex value) const { + if (!shape().IsArray()) { + return 0; + } + PrimitiveType ty = shape().element_type(); + Literal scalar(ShapeUtil::MakeScalarShape(ty)); + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> int64_t { + if constexpr (primitive_util::IsComplexType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + scalar.Set({}, static_cast(value)); + return root_piece().CountAll(scalar); + } + return 0; + }, + ty); +} + template TF_ATTRIBUTE_NOINLINE void LiteralBase::EachCell( absl::FunctionRef indices, NativeT value)> diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 8258ad9f6e7b73..fa7218a11bf1b1 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" +#include #include #include #include @@ -684,6 +685,60 @@ TEST_F(LiteralUtilTest, IsAllFirst) { EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } +TEST_F(LiteralUtilTest, CountEqualInt) { + EXPECT_EQ(LiteralUtil::CreateR1({}).CountEqual(1), 0); + EXPECT_EQ( + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 100}).CountEqual(2), + 1); + EXPECT_EQ(LiteralUtil::CreateR1({0, 3, 6, 0, 9, 18, 0}) + .CountEqual(0), + 3); + EXPECT_EQ(LiteralUtil::CreateR1({234, 345, 4, 45, 5467, 5467, 5467}) + .CountEqual(5467), + 3); +} + +TEST_F(LiteralUtilTest, CountEqualFloat) { + EXPECT_EQ(LiteralUtil::CreateR1({}).CountEqual(0), 0); + EXPECT_EQ(LiteralUtil::CreateR1({1.1, 2.2, 3.3, 4.4, 5.5, 100.6}) + .CountEqual(3.3), + 1); + EXPECT_EQ(LiteralUtil::CreateR1({7.62, 3, 7.75, 7.62, 7.3, 2, 7.62}) + .CountEqual(7.62), + 3); + EXPECT_EQ(LiteralUtil::CreateR1( + {NAN, 0, 6.8, NAN, NAN, NAN, 63.12, 24.6, NAN}) + .CountEqual(NAN), + 5); +} + +TEST_F(LiteralUtilTest, CountEqualBool) { + EXPECT_EQ(LiteralUtil::CreateR1({false, true}).CountEqual(false), + 1); +} + +TEST_F(LiteralUtilTest, CountEqualComplex) { + EXPECT_EQ(LiteralUtil::CreateR1>( + {std::complex(1, 2), std::complex(3, 4), + std::complex(5, 6), std::complex(6, 7)}) + .CountEqual(std::complex(5, 6)), + 1); +} + +TEST_F(LiteralUtilTest, CountEqualMismatched) { + EXPECT_EQ(LiteralUtil::CreateR1({13, 10.5, 15.6, 22.7}) + .CountEqual(13), + 1); + EXPECT_EQ( + LiteralUtil::CreateR1({10.5, 15.6, 22.7}).CountEqual(1), + 0); + EXPECT_EQ(LiteralUtil::CreateR1>( + {std::complex(1, 2), std::complex(3, 4), + std::complex(5, 6), std::complex(6, 7)}) + .CountEqual(1), + 0); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); From e2f8d3812090dec0de1bc19cad81c9e68bad642e Mon Sep 17 00:00:00 2001 From: "Varghese, Jojimon" Date: Tue, 25 Jul 2023 16:45:30 -0700 Subject: [PATCH 155/410] Fix to avoid possible memory leak when hash map is allocated --- tensorflow/core/graph/mkl_graph_util.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h index 886c4051c8ca13..1a2d235dcc3b20 100644 --- a/tensorflow/core/graph/mkl_graph_util.h +++ b/tensorflow/core/graph/mkl_graph_util.h @@ -196,8 +196,8 @@ static inline bool IsMklOp(const string& op_name, DataType T, string label = is_native_op ? 
kMklNameChangeOpLabelPattern : kMklLayoutDependentOpLabelPattern; string registered_kernels_key = op_name + label + std::to_string(T); - thread_local static auto* registered_kernels_map = - new absl::flat_hash_map(); + thread_local static auto registered_kernels_map = + std::make_unique>(); auto kernel_element = registered_kernels_map->find(registered_kernels_key); bool kernel_registered = false; From 2701cc4c24b637aa2972a92acbc2fc76f6cf974d Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Wed, 26 Jul 2023 00:56:45 +0100 Subject: [PATCH 156/410] Fixes for failing tests from CI --- .../optimize_function_graph_utils.cc | 24 ++++++++++++------- .../optimize_function_graph_utils.h | 3 ++- .../process_function_library_runtime.cc | 8 +++---- tensorflow/core/util/BUILD | 6 ++--- tensorflow/core/util/mkl_heuristics.h | 3 +-- 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc index 222dab9efbaabe..91d9238c8fc8c0 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/optimized_function_graph_info.h" #include "tensorflow/core/common_runtime/partitioning_utils.h" @@ -553,13 +554,6 @@ StatusOr OptimizeFunctionGraph( options.shape_inference_on_tfe_dialect_import; optimization_options.debug_filename_prefix = function_name; - if (cpu_device->tensorflow_cpu_worker_threads() != nullptr) { - // Forward to the optimisation pass number of intra threads that are used to - // parallelise operations. - session_options.config.set_intra_op_parallelism_threads( - cpu_device->tensorflow_cpu_worker_threads()->num_threads); - } - DEBUG_DATA_DUMPER()->DumpGraph(function_name, kDebugGroupMain, "before_pre_placement_passes", graph.get(), &reachable_lib_def, false); @@ -747,7 +741,8 @@ PreprocessAndPartitionGraph( OptimizedFunctionGraphInfo& input_optimized_graph, const FunctionLibraryRuntime::InstantiateOptions& options, const DeviceSet& dev_set, const FunctionLibraryDefinition* input_lib_def, - const std::vector& composite_devices, Env* env) { + const std::vector& composite_devices, Env* env, + Device* cpu_device) { std::unique_ptr& graph = input_optimized_graph.function_graph; // Expand the nodes assigned to a CompositeDevice before graph partition to @@ -790,6 +785,10 @@ PreprocessAndPartitionGraph( // Doing post-partitioning passes. 
GraphOptimizationPassOptions optimization_options; + SessionOptions session_options; + session_options.env = env; + session_options.config = options.config_proto; + optimization_options.session_options = &session_options; optimization_options.flib_def = &(input_optimized_graph.lib_def); optimization_options.is_function_graph = true; optimization_options.graph = nullptr; @@ -797,6 +796,15 @@ PreprocessAndPartitionGraph( optimization_options.partition_graphs = device_name_to_subgraphs.get(); optimization_options.debug_filename_prefix = function_name; + // As not all devices might set number of threads for intra op + // parallelisation we restrict this only to local device which does. + if (cpu_device && dynamic_cast(cpu_device)) { + // Forward to the optimisation pass number of intra threads that are used to + // parallelise operations. + session_options.config.set_intra_op_parallelism_threads( + cpu_device->tensorflow_cpu_worker_threads()->num_threads); + } + // Normally POST_PARTITIONING passes are run by distributed workers. // Distributed workers are currently not supported in this code path, so we // run the passes here. diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils.h b/tensorflow/core/common_runtime/optimize_function_graph_utils.h index 72fec7528dc582..3ccf9bedf5adc9 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils.h +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils.h @@ -93,7 +93,8 @@ PreprocessAndPartitionGraph( OptimizedFunctionGraphInfo& input_optimized_graph, const FunctionLibraryRuntime::InstantiateOptions& options, const DeviceSet& dev_set, const FunctionLibraryDefinition* input_lib_def, - const std::vector& composite_devices, Env* env); + const std::vector& composite_devices, Env* env, + Device* cpu_device); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index fcb6bd2f52e841..6f89c68a01673e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -587,10 +587,10 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( optimized_graph_info->function_graph->mutable_flib_def() ->set_default_registry(&(optimized_graph_info->lib_def)); - TF_ASSIGN_OR_RETURN( - auto subgraphs, - PreprocessAndPartitionGraph(function_name, *optimized_graph_info, options, - *dev_set, lib_def_, composite_devices, env_)); + TF_ASSIGN_OR_RETURN(auto subgraphs, PreprocessAndPartitionGraph( + function_name, *optimized_graph_info, + options, *dev_set, lib_def_, + composite_devices, env_, cpu_device)); const uint64 optimization_end_time_usecs = Env::Default()->NowMicros(); const uint64 graph_optimization_duration = optimization_end_time_usecs - optimization_start_time_usecs; diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 6c122b264aeb0b..acd4825c36004c 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -10,8 +10,8 @@ load( "//tensorflow:tensorflow.bzl", "check_deps", "tf_cc_test", - "tf_cc_tests", "tf_cc_test_mkl", + "tf_cc_tests", "tf_copts", "tf_cuda_library", "tf_cuda_only_cc_test", @@ -966,11 +966,11 @@ tf_cc_test_mkl( srcs = ["mkl_heuristics_test.cc"], linkstatic = 1, # Fixes dyld error on MacOS. 
deps = [ + "//tensorflow/core:framework", "//tensorflow/core:framework_lite", + "//tensorflow/core:graph", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:graph", - "//tensorflow/core:framework", "//tensorflow/core/kernels:ops_testutil", ], ) diff --git a/tensorflow/core/util/mkl_heuristics.h b/tensorflow/core/util/mkl_heuristics.h index 518712b85c7d5f..f8f88b6100e1fb 100644 --- a/tensorflow/core/util/mkl_heuristics.h +++ b/tensorflow/core/util/mkl_heuristics.h @@ -22,8 +22,7 @@ limitations under the License. #include -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/tsl/platform/cpu_info.h" namespace tensorflow { From 7dd34af8e41c00e528b14d545c0f797a034e12fe Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 25 Jul 2023 16:53:14 -0700 Subject: [PATCH 157/410] [xla][gpu] Use different resources to model asynchronous collective and P2P operations scheduling. Previously, we use the same resource to model the scheduling of asynchronous collective operations and P2P Send and Recv operations. This assumes we use one runtime stream to implement these asynchronous operations. We have just modified the runtime to user two streams, one for collective operations and one for P2P Send and Recv operations. As such, we now modify the HLO scheduler to use different resources to schedule collective operations and P2P operations, to match the runtime implementation. Add a test case. PiperOrigin-RevId: 551036618 --- .../xla/service/gpu/gpu_hlo_schedule.cc | 31 ++++--- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 92 +++++++++++++++++++ 2 files changed, 110 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index bb46e1aa421da5..d32f414f979af0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -245,14 +245,15 @@ SchedulerConfig GetSchedulerConfig(const GpuDeviceInfo& gpu_info) { // GPU specific resources for latency hiding scheduler. // -// We use two resources to model collective operations: a resource for sending -// data and a resource for receiving data. All collective operations require -// both resources while the Send and Recv operations requires only the single -// resource corresponding to the operation. +// We use two different set of resources to model the scheduling of asynchronous +// collective operations and P2P Send and Recv operations. This corresponds to +// the fact that the runtime use a stream to run asynchronous collective +// operations and another stream to run P2P Send and Recv operations. enum class GpuResourceType { - kGpuAsyncStreamSend = 0, // The resource for sending data. - kGpuAsyncStreamRecv = 1, // The resource for receiving data. - kNumTargetResources = 2, + kGpuAsyncStreamSend = 0, // The resource for P2P Send operation. + kGpuAsyncStreamRecv = 1, // The resource for P2P Recv operation. + kGpuAsyncStreamCollectives = 2, // The resource for collective operations. 
+ kNumTargetResources = 3, }; // Base GPU async tracker that enables async tracking only for async collectives @@ -301,11 +302,12 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { resources.push_back(std::make_pair(gpu_stream_resource, usage)); }; - if (op.inner != HloOpcode::kRecv) { + if (op.inner == HloOpcode::kSend) { add_resource(GpuResourceType::kGpuAsyncStreamSend); - } - if (op.inner != HloOpcode::kSend) { + } else if (op.inner == HloOpcode::kRecv) { add_resource(GpuResourceType::kGpuAsyncStreamRecv); + } else { + add_resource(GpuResourceType::kGpuAsyncStreamCollectives); } return resources; } @@ -346,11 +348,14 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { } CHECK_LE(resource_type, first_target_resource + GetNumTargetDefinedResources()); - switch (resource_type - first_target_resource) { - case static_cast(GpuResourceType::kGpuAsyncStreamSend): + switch ( + static_cast(resource_type - first_target_resource)) { + case GpuResourceType::kGpuAsyncStreamSend: return "kGpuAsyncStreamSend"; - case static_cast(GpuResourceType::kGpuAsyncStreamRecv): + case GpuResourceType::kGpuAsyncStreamRecv: return "kGpuAsyncStreamRecv"; + case GpuResourceType::kGpuAsyncStreamCollectives: + return "kGpuAsyncStreamCollectives"; default: return "kUnsupportedResource"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 68260838a41e98..46c6fa8ecb17fb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -664,6 +664,98 @@ TEST_F(GpuHloScheduleTest, LHSSendRecvPairs2) { EXPECT_LT(get_index("send-done-0"), get_index("send-1")); } +// Checks that asynchronous AllReduce is scheduled to interleave with the Send +// and Recv sequence. 
+TEST_F(GpuHloScheduleTest, LHSSendRecvAllReduce) { + const char* hlo_text = R"( + HloModule test + add (x: f32[], y: f32[]) -> f32[] { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(f32[] x, f32[] y) + } + + while_cond { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + ub = u32[] constant(25) + ROOT cond_result = pred[] compare(count, ub), direction=LT + } + + while_body { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(%param), index=0 + send-data = get-tuple-element(%param), index=1 + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all), + channel_id=1, control-predecessors={recv}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } + recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1, control-predecessors={send} + send-done = token[] send-done(send), control-predecessors={recv-done}, channel_id=1 + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 + + c1 = u32[] constant(1) + new_count = u32[] add(count, c1) + replica = u32[] replica-id() + c10 = u32[] constant(10) + sum = u32[] add(replica, c10) + sum2 = u32[] add(sum, count) + conv = f32[] convert(sum2) + p = f32[1, 1024, 1024] broadcast(conv), dimensions={} + b = f32[1, 1024, 1024] add(p, recv-data) + c = f32[1, 1024, 1024] multiply(b, b) + d = f32[1, 1024, 1024] tan(c) + s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} + + all-reduce-start = f32[1, 1024, 1024] all-reduce-start(f32[1, 1024, 1024] p), + replica_groups={{0,1}}, to_apply=add, backend_config={"is_sync":false} + all-reduce-done = f32[1, 1024, 1024] all-reduce-done(f32[1, 1024, 1024] all-reduce-start) + new-data = f32[1, 1024, 1024] add(s, all-reduce-done) + ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, new-data) + } + + ENTRY test_computation { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + while_init = (u32[], f32[1, 1024, 1024]) tuple(c0, init) + while_result = (u32[], f32[1, 1024, 1024]) while(while_init), + body=while_body, condition=while_cond + ROOT entry_result = f32[1, 1024, 1024] get-tuple-element(while_result), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true))); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + HloComputation* while_body = module->GetComputationWithName("while_body"); + const std::vector& instruction_sequence = + order.SequentialOrder(*while_body)->instructions(); + auto get_index = [&](absl::string_view hlo_name) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); + }; + + EXPECT_LT(get_index("recv"), get_index("send")); + EXPECT_LT(get_index("send"), get_index("recv-done")); + EXPECT_GE(get_index("send-done") - get_index("recv-done"), 9); + EXPECT_GT(get_index("send-done"), get_index("all-reduce-done")); + EXPECT_TRUE(HasValidFingerprint(module.get())); +} + class GpuHloScheduleParameterizedTest : public GpuHloScheduleTest, public 
::testing::WithParamInterface {}; From a4a92fefd31c556d7532aa72f94dea38b5af4eca Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 25 Jul 2023 17:02:45 -0700 Subject: [PATCH 158/410] [XLA:Python] Use int64_t instead of ssize_t to fix Mac compiler errors (hopefully). Apparently ssize_t is only a long sometimes (at least 32-bit), instead of long long (at least 64-bit). I don't have a mac so I can't repro the failing build, but hopefully this fixes it based on the error message. PiperOrigin-RevId: 551038967 --- tensorflow/compiler/xla/python/py_buffer.cc | 2 +- .../compiler/xla/python/py_host_callback.cc | 2 +- tensorflow/compiler/xla/python/types.cc | 30 +++++-------------- tensorflow/compiler/xla/python/types.h | 7 ++--- 4 files changed, 13 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index 2558c885220cfa..0c63259c7591c1 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -65,7 +65,7 @@ std::optional> ByteStridesOrDefaultForShapeInt64( if (!shape.has_layout() || HasMajorToMinorLayout(shape)) { return std::nullopt; } - return ByteStridesForShapeInt64(shape); + return ByteStridesForShape(shape); } } // namespace diff --git a/tensorflow/compiler/xla/python/py_host_callback.cc b/tensorflow/compiler/xla/python/py_host_callback.cc index d49ef458e024d2..1aed6f4429da43 100644 --- a/tensorflow/compiler/xla/python/py_host_callback.cc +++ b/tensorflow/compiler/xla/python/py_host_callback.cc @@ -86,7 +86,7 @@ StatusOr> CreateCallbackResults( callback_results[i].expected_dims.resize(shape.dimensions_size()); absl::c_copy(shape.dimensions(), callback_results[i].expected_dims.begin()); - callback_results[i].expected_strides = ByteStridesForShapeInt64(shape); + callback_results[i].expected_strides = ByteStridesForShape(shape); callback_results[i].type = shape.element_type(); callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape); callback_results[i].reversed_layout.resize(shape.dimensions_size()); diff --git a/tensorflow/compiler/xla/python/types.cc b/tensorflow/compiler/xla/python/types.cc index e7fb22c64812cc..0642a8a8cc6ff4 100644 --- a/tensorflow/compiler/xla/python/types.cc +++ b/tensorflow/compiler/xla/python/types.cc @@ -342,21 +342,21 @@ PrimitiveType Squash64BitTypes(PrimitiveType type) { } // Returns the strides for `shape`. 
-std::vector ByteStridesForShape(const Shape& shape) { - std::vector strides; +std::vector ByteStridesForShape(const Shape& shape) { + std::vector strides; CHECK(shape.IsArray()); CHECK(shape.has_layout()); return ByteStridesForShape(shape.element_type(), shape.dimensions(), shape.layout()); } -static std::vector StridesForShapeHelper( +static std::vector StridesForShapeHelper( PrimitiveType element_type, absl::Span dimensions, - const xla::Layout& layout, ssize_t innermost_stride_size) { + const xla::Layout& layout, int64_t innermost_stride_size) { CHECK_EQ(dimensions.size(), layout.minor_to_major().size()); - std::vector strides; + std::vector strides; strides.resize(dimensions.size()); - ssize_t stride = innermost_stride_size; + int64_t stride = innermost_stride_size; for (int i : layout.minor_to_major()) { strides[i] = stride; stride *= dimensions[i]; @@ -364,7 +364,7 @@ static std::vector StridesForShapeHelper( return strides; } -std::vector ByteStridesForShape(PrimitiveType element_type, +std::vector ByteStridesForShape(PrimitiveType element_type, absl::Span dimensions, const xla::Layout& layout) { return StridesForShapeHelper( @@ -372,27 +372,13 @@ std::vector ByteStridesForShape(PrimitiveType element_type, ShapeUtil::ByteSizeOfPrimitiveType(element_type)); } -std::vector StridesForShape(PrimitiveType element_type, +std::vector StridesForShape(PrimitiveType element_type, absl::Span dimensions, const xla::Layout& layout) { return StridesForShapeHelper(element_type, dimensions, layout, /*innermost_stride_size=*/1); } -std::vector ByteStridesForShapeInt64(const Shape& shape) { - std::vector strides; - CHECK(shape.IsArray()); - CHECK(shape.has_layout()); - - strides.resize(shape.dimensions_size()); - int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); - for (int i : shape.layout().minor_to_major()) { - strides.at(i) = stride; - stride *= shape.dimensions(i); - } - return strides; -} - StatusOr LiteralToPython(std::shared_ptr literal) { xla::Literal& m = *literal; if (m.shape().IsTuple()) { diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index 685db05580dc43..7418da2a5494d4 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -83,14 +83,13 @@ const NumpyScalarTypes& GetNumpyScalarTypes(); PrimitiveType Squash64BitTypes(PrimitiveType type); // Returns the strides for `shape`. -std::vector ByteStridesForShape(const Shape& shape); -std::vector ByteStridesForShape(PrimitiveType element_type, +std::vector ByteStridesForShape(const Shape& shape); +std::vector ByteStridesForShape(PrimitiveType element_type, absl::Span dimensions, const xla::Layout& layout); -std::vector StridesForShape(PrimitiveType element_type, +std::vector StridesForShape(PrimitiveType element_type, absl::Span dimensions, const xla::Layout& layout); -std::vector ByteStridesForShapeInt64(const Shape& shape); // Converts a literal to (possibly-nested tuples of) NumPy arrays. // The literal's leaf arrays are not copied; instead the NumPy arrays share From e51812bae31268bfb0f405eecc23cbd25ae2844e Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Tue, 25 Jul 2023 17:20:41 -0700 Subject: [PATCH 159/410] #tf-data-service Fix typo. 
PiperOrigin-RevId: 551043155 --- tensorflow/core/data/service/dispatcher_state.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index 804ca7971f3748..4ace89d4d27d14 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -289,7 +289,7 @@ class DispatcherState { // deterministically sharding a dataset among a fixed set of workers. StatusOr GetWorkerIndex(absl::string_view worker_address) const; - // Returns the paths of all snapshots inititated during the lifetime of this + // Returns the paths of all snapshots initiated during the lifetime of this // journal. const absl::flat_hash_set& ListSnapshotPaths() const { return snapshot_paths_; From b245b5657184d96c2921e78367c2b90f1d5df2a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 17:25:04 -0700 Subject: [PATCH 160/410] Add WeakTensor support to all binary mathematical dunder methods and a subset of mathematical binary ops in math_ops.py (add, subtract, multply, scalar_mul, matmul, multiply_no_nan, divide, true_div, floor_div, real_div, truncate_div, divide, div_no_nan, mod, truncate_div, pow). PiperOrigin-RevId: 551044031 --- tensorflow/python/ops/BUILD | 7 + tensorflow/python/ops/math_ops.py | 14 +- .../python/ops/weak_tensor_image_ops_test.py | 26 +- .../python/ops/weak_tensor_math_ops_test.py | 275 ++++++- tensorflow/python/ops/weak_tensor_ops.py | 85 +- tensorflow/python/ops/weak_tensor_ops_test.py | 734 +++++++++++++++++- .../python/ops/weak_tensor_test_util.py | 2 + 7 files changed, 1078 insertions(+), 65 deletions(-) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 7db4f82ed12b17..9343cc010648b3 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -4488,11 +4488,13 @@ py_strict_test( ":clip_ops", ":image_ops", ":math_ops", + ":math_ops_gen", ":weak_tensor_ops", ":weak_tensor_test_util", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", + "//tensorflow/python/framework:flexible_dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", @@ -4514,7 +4516,9 @@ py_strict_test( deps = [ ":array_ops", ":math_ops", + ":resource_variable_ops", ":tensor_array_ops", + ":variables", ":weak_tensor_ops", ":weak_tensor_test_util", "//tensorflow/core:protos_all_py", @@ -4522,7 +4526,9 @@ py_strict_test( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:indexed_slices", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:test_lib", @@ -4672,6 +4678,7 @@ py_strict_library( srcs = ["weak_tensor_test_util.py"], deps = [ "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:weak_tensor", "//third_party/py/numpy", diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index c90bde289564df..10bdf8dac1ec2a 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1428,6 +1428,8 @@ def maybe_promote_tensors(*tensors, force_same_dtype=False): Returns: The 
promoted list of tensors. """ + if ops.is_auto_dtype_conversion_enabled(): + return tensors if not tensors: return tensors if not ops.is_numpy_style_type_promotion(): @@ -1837,13 +1839,13 @@ def _add_dispatch(x, y, name=None): Returns: The result of the elementwise `+` operation. """ - if not isinstance(y, tensor_lib.Tensor) and not isinstance( - y, sparse_tensor.SparseTensor): + if ( + not ops.is_auto_dtype_conversion_enabled() + and not isinstance(y, tensor_lib.Tensor) + and not isinstance(y, sparse_tensor.SparseTensor) + ): y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y") - if x.dtype == dtypes.string: - return gen_math_ops.add(x, y, name=name) - else: - return gen_math_ops.add_v2(x, y, name=name) + return add(x, y, name=name) def _mul_dispatch(x, y, name=None): diff --git a/tensorflow/python/ops/weak_tensor_image_ops_test.py b/tensorflow/python/ops/weak_tensor_image_ops_test.py index b9603a867e7cb5..fe037723eb8098 100644 --- a/tensorflow/python/ops/weak_tensor_image_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_image_ops_test.py @@ -64,7 +64,7 @@ def testPositiveDeltaFloat64(self): class AdjustGamma(test.TestCase): def test_adjust_gamma_less_zero_float32(self): - """White image should be returned for gamma equal to zero""" + """White image should be returned for gamma equal to zero.""" with self.cached_session(): x_data = np.random.uniform(0, 1.0, (8, 8)) x_np = np.array(x_data, dtype=np.float32) @@ -77,7 +77,7 @@ def test_adjust_gamma_less_zero_float32(self): image_ops.adjust_gamma(x, gamma=-1) def test_adjust_gamma_less_zero_tensor(self): - """White image should be returned for gamma equal to zero""" + """White image should be returned for gamma equal to zero.""" with self.cached_session(): x_data = np.random.uniform(0, 1.0, (8, 8)) x_np = np.array(x_data, dtype=np.float32) @@ -92,10 +92,7 @@ def test_adjust_gamma_less_zero_tensor(self): self.evaluate(image) def _test_adjust_gamma_float32(self, gamma): - """Verifying the output with expected results for gamma - - correction for float32 images - """ + """Verifying the output with expected results for gamma correction for float32 images.""" with self.cached_session(): x_np = np.random.uniform(0, 1.0, (8, 8)) x = _get_weak_tensor(x_np, shape=x_np.shape) @@ -108,28 +105,19 @@ def _test_adjust_gamma_float32(self, gamma): self.assertAllClose(y_tf, y_np, 1e-6) def test_adjust_gamma_one_float32(self): - """Same image should be returned for gamma equal to one""" + """Same image should be returned for gamma equal to one.""" self._test_adjust_gamma_float32(1.0) def test_adjust_gamma_less_one_float32(self): - """Verifying the output with expected results for gamma - - correction with gamma equal to half for float32 images - """ + """Verifying the output with expected results for gamma correction with gamma equal to half for float32 images.""" self._test_adjust_gamma_float32(0.5) def test_adjust_gamma_greater_one_float32(self): - """Verifying the output with expected results for gamma - - correction with gamma equal to two for float32 images - """ + """Verifying the output with expected results for gamma correction with gamma equal to two for float32 images.""" self._test_adjust_gamma_float32(1.0) def test_adjust_gamma_zero_float32(self): - """White image should be returned for gamma equal - - to zero for float32 images - """ + """White image should be returned for gamma equal to zero for float32 images.""" self._test_adjust_gamma_float32(0.0) diff --git a/tensorflow/python/ops/weak_tensor_math_ops_test.py 
b/tensorflow/python/ops/weak_tensor_math_ops_test.py index 82dda558c67c8f..3cbcd133ea1132 100644 --- a/tensorflow/python/ops/weak_tensor_math_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_math_ops_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Tests for tensorflow.ops.math_ops on WeakTensor.""" + from absl.testing import parameterized import numpy as np @@ -21,14 +22,18 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util from tensorflow.python.framework.weak_tensor import WeakTensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables from tensorflow.python.ops import weak_tensor_ops # pylint: disable=unused-import from tensorflow.python.ops import weak_tensor_test_util from tensorflow.python.ops.ragged import ragged_factory_ops @@ -50,7 +55,7 @@ class ReduceTest(test_util.TensorFlowTestCase, parameterized.TestCase): ) def testReduceAllDims(self, input_type, result_type): test_input = _convert_to_input_type( - [[1, 2, 3], [4, 5, 6]], input_type, np.int32 + [[1, 2, 3], [4, 5, 6]], input_type, dtypes.int32 ) with test_util.device(use_gpu=True): res = math_ops.reduce_sum(test_input) @@ -88,7 +93,7 @@ def testCountNonzero(self): ("Tensor", tensor.Tensor), ) def testReduceExplicitAxes(self, input_type, result_type): - x = _convert_to_input_type([[1, 2, 3], [4, 5, 6]], input_type, np.int32) + x = _convert_to_input_type([[1, 2, 3], [4, 5, 6]], input_type, dtypes.int32) with test_util.device(use_gpu=True): for axis in (0, -2): res = math_ops.reduce_sum(x, axis=axis) @@ -453,6 +458,272 @@ def test_fn(): self.assertAllEqual(self.evaluate(test_fn()), [1]) + +@test_util.run_all_in_graph_and_eager_modes +class BinaryOpsTest(test_util.TensorFlowTestCase): + + def testRHSDispatchingAndErrorRaising(self): + if context.executing_eagerly(): + error = ValueError + error_message = r"Attempt to convert a value .* with an unsupported type" + else: + error = TypeError + error_message = r"Failed to convert elements of .* to Tensor" + + class RHSReturnsTrue: + + def __radd__(self, other): + return True + + a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsTrue() + self.assertEqual(a, True) + + a = _get_weak_tensor(5, dtype=dtypes.int32) + RHSReturnsTrue() + self.assertEqual(a, True) + + class RHSRaisesError: + + def __radd__(self, other): + raise TypeError("RHS not implemented") + + with self.assertRaisesRegex(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSRaisesError() + self.evaluate(a) + + with self.assertRaisesRegex(error, error_message): + a = _get_weak_tensor([1], dtype=dtypes.int32) + RHSRaisesError() + self.evaluate(a) + + class RHSReturnsNotImplemented: + + def __radd__(self, other): + return NotImplemented + + with self.assertRaisesRegex(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsNotImplemented() + self.evaluate(a) + + a = _get_weak_tensor([1], dtype=dtypes.int32) + 
RHSReturnsNotImplemented() + self.evaluate(a) + + class RHSNotImplemented: + pass + + with self.assertRaisesRegex(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSNotImplemented() + self.evaluate(a) + + a = _get_weak_tensor([1], dtype=dtypes.int32) + RHSNotImplemented() + self.evaluate(a) + + +@test_util.run_all_in_graph_and_eager_modes +class ScalarMulTest(test_util.TensorFlowTestCase): + + def testAcceptsRefs(self): + if context.executing_eagerly(): + var = resource_variable_ops.ResourceVariable(10, name="var") + else: + var = variables.Variable(10) + result = math_ops.scalar_mul(3, var) + init = variables.global_variables_initializer() + with test_util.device(use_gpu=True): + self.evaluate(init) + self.assertEqual(30, self.evaluate(result)) + + def testAcceptsIndexedSlices(self): + values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2]) + indices = constant_op.constant([0, 2, 5]) + # Test that patched scalar_mul works with IndexedSlices. + x = math_ops.scalar_mul(-3, indexed_slices.IndexedSlices(values, indices)) + with test_util.device(use_gpu=True): + self.assertAllEqual( + self.evaluate(x.values), [[-6, -9], [-15, -21], [0, 3]] + ) + self.assertAllEqual(self.evaluate(x.indices), [0, 2, 5]) + + +@test_util.run_all_in_graph_and_eager_modes +class DivAndModTest(test_util.TensorFlowTestCase): + + def numpySafeFloorDivInt(self, x, y): + z = x // y + # Numpy produces 0 for INT_MIN/-1, but we expect an overflow to INT_MIN + # so that (INT_MIN/-1) + (INT_MIN % -1) = INT_MIN + 0 = INT_MIN. + z[(x == np.iinfo(x.dtype).min) & (y == -1)] = np.iinfo(x.dtype).min + return z + + def numpySafeFloorModInt(self, x, y): + # Numpy crashes with a FPE for INT_MIN % -1. + z = self.numpySafeFloorDivInt(x, y) + return x - z * y + + def numpySafeTruncateDivInt(self, x, y): + z = self.numpySafeFloorDivInt(x, y) + # Round up if non-zero remainder and inputs have opposite signs. + z[(x != z * y) & ((x < 0) != (y < 0))] += 1 + return z + + def numpySafeTruncateModInt(self, x, y): + # Numpy crashes with a FPE for INT_MIN % -1. + z = self.numpySafeTruncateDivInt(x, y) + return x - z * y + + def intEdgeTestData(self, dtype): + """Edge-case test data for integer types.""" + # INT_MIN/-1 expected to produce signed-integer overflow, + # INT_MIN/INT_MAX expected to work. 
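+      # (INT_MIN / -1 equals INT_MAX + 1 in exact arithmetic, which is not
+      # representable, so the expected two's-complement result wraps back to
+      # INT_MIN; INT_MIN / INT_MAX stays in range and needs no special case.)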
+ nums = np.array( + [np.iinfo(dtype).min, -1, 1, np.iinfo(dtype).max], dtype=dtype + ).reshape([4, 1]) + divs = nums.reshape([1, 4]) + return nums, divs + + @test_util.disable_asan("Expected signed integer overflow.") + @test_util.disable_ubsan("Expected signed integer overflow.") + def testFloorDivModIntEdges(self): + for dtype in [np.int32, np.int64]: + x, y = self.intEdgeTestData(dtype) + x_weak, y_weak = _get_weak_tensor(x), _get_weak_tensor(y) + tf_floor_div = math_ops.floor_div(x_weak, y_weak) + np_floor_div = self.numpySafeFloorDivInt(x, y) + self.assertIsInstance(tf_floor_div, WeakTensor) + self.assertAllEqual(tf_floor_div, np_floor_div) + + tf_floor_mod = math_ops.floormod(x_weak, y_weak) + np_floor_mod = self.numpySafeFloorModInt(x, y) + self.assertIsInstance(tf_floor_div, WeakTensor) + self.assertAllEqual(tf_floor_mod, np_floor_mod) + z = math_ops.add(math_ops.multiply(tf_floor_div, y_weak), tf_floor_mod) + # x = floor_div(x, y) * y + floor_mod(x, y) + self.assertIsInstance(z, WeakTensor) + self.assertAllEqual(z, np.broadcast_to(x, z.shape)) + + @test_util.disable_asan("Expected signed integer overflow.") + @test_util.disable_ubsan("Expected signed integer overflow.") + def testTruncateDivModIntEdges(self): + for dtype in [np.int32, np.int64]: + x, y = self.intEdgeTestData(dtype) + x_weak, y_weak = _get_weak_tensor(x), _get_weak_tensor(y) + tf_truncate_div = math_ops.truncatediv(x_weak, y_weak) + np_truncate_div = self.numpySafeTruncateDivInt(x, y) + self.assertIsInstance(tf_truncate_div, WeakTensor) + self.assertAllEqual(tf_truncate_div, np_truncate_div) + + tf_truncate_mod = math_ops.truncatemod(x_weak, y_weak) + np_truncate_mod = self.numpySafeTruncateModInt(x, y) + self.assertIsInstance(tf_truncate_mod, WeakTensor) + self.assertAllEqual(tf_truncate_mod, np_truncate_mod) + z = math_ops.add( + math_ops.multiply(tf_truncate_div, y_weak), tf_truncate_mod + ) + self.assertIsInstance(z, WeakTensor) + # x = truncatediv(x, y) * y + truncatemod(x, y) + self.assertAllEqual(z, np.broadcast_to(x, z.shape)) + + +@test_util.run_all_in_graph_and_eager_modes +class DivNoNanTest(test_util.TensorFlowTestCase, parameterized.TestCase): + _SUPPORTED_DTYPES = [ + dtypes.int32, + dtypes.int64, + dtypes.float32, + dtypes.float64, + dtypes.complex128, + ] + + @parameterized.parameters(*_SUPPORTED_DTYPES) + def testBasic(self, dtype): + if dtype.is_unsigned: + nums = np.arange(0, 120, 3).reshape(40, 1) + divs = np.arange(0, 48, 4).reshape(1, 12) + elif dtype.is_integer: + nums = np.arange(-120, 120, 3).reshape(80, 1) + divs = np.arange(-48, 48, 4).reshape(1, 24) + else: + nums = np.arange(-10, 10, 0.25).reshape(80, 1) + divs = np.arange(-3, 3, 0.25).reshape(1, 24) + assert 0 in divs, "Bad test set-up" + + tf_nums = _get_weak_tensor(nums, dtype=dtype) + tf_divs = _get_weak_tensor(divs, dtype=dtype) + + # Use tf versions for expected value to ensure inputs are identical + np_nums = self.evaluate(tf_nums) + np_divs = self.evaluate(tf_divs) + np_result = np.true_divide(np_nums, np_divs) + np_result[:, np_divs[0] == 0] = 0 + + with test_util.use_gpu(): + tf_result = math_ops.div_no_nan(tf_nums, tf_divs) + self.assertIsInstance(tf_result, WeakTensor) + self.assertAllCloseAccordingToType(tf_result, np_result) + + @parameterized.product( + type_x=_SUPPORTED_DTYPES + [float, int], + type_y=_SUPPORTED_DTYPES + [float, int], + ) + def testSameSupportedTypesAsDivide(self, type_x, type_y): + def one(type_): + if type_ is int: + return 1 + elif type_ is float: + return 1.0 + else: + return _get_weak_tensor(1, 
dtype=type_) + + x = one(type_x) + y = one(type_y) + + divide_raises = False + try: + divide_result = math_ops.divide(x, y) + except TypeError: + divide_raises = True + + if divide_raises: + with self.assertRaises(TypeError): + _ = math_ops.div_no_nan(x, y) + else: + divide_no_nan_result = math_ops.div_no_nan(x, y) + self.assertIsInstance(divide_no_nan_result, WeakTensor) + self.assertIsInstance(divide_result, WeakTensor) + self.assertEqual(divide_no_nan_result.dtype, divide_result.dtype) + self.assertAllEqual(divide_no_nan_result, divide_result) + + @parameterized.parameters( + (dtypes.float32), + (dtypes.float64), + (dtypes.complex128), + ) + def testSmall(self, dtype): + # Choose values whose squared magnitude underflows to zero/subnormal. + zero = _get_weak_tensor([0, 0, 0, 0], dtype=dtype) + divs = _get_weak_tensor([1e-25, -1e-20, 1e-165, -1e-160], dtype=dtype) + tf_result = math_ops.div_no_nan(zero, divs) + + # Results should always be exactly zero. + self.assertAllEqual(tf_result, zero) + self.assertIsInstance(tf_result, WeakTensor) + + @parameterized.parameters( + (dtypes.float32), + (dtypes.float64), + (dtypes.complex128), + ) + def testNonFiniteInNumerator(self, dtype): + nums = _get_weak_tensor([np.nan, np.inf, np.NINF], dtype=dtype) + zeros = _get_weak_tensor([0, 0, 0], dtype=dtype) + ones = _get_weak_tensor([1, 1, 1], dtype=dtype) + with test_util.use_gpu(): + tf_result_zeros = math_ops.div_no_nan(nums, zeros) + self.assertAllEqual([0, 0, 0], tf_result_zeros) + self.assertIsInstance(tf_result_zeros, WeakTensor) + tf_result_ones = math_ops.div_no_nan(nums, ones) + self.assertAllEqual(nums / ones, tf_result_ones) + self.assertIsInstance(tf_result_ones, WeakTensor) + + if __name__ == "__main__": ops.set_dtype_conversion_mode("all") googletest.main() diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index ad11e679617435..ec92340712c71a 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -56,7 +56,7 @@ def _convert_or_cast(x, dtype, name): return math_ops.cast(x, dtype=dtype, name=name) -def weak_tensor_unary_op_wrapper(op): +def weak_tensor_unary_op_wrapper(op, x_arg_name=None): """Infers input type and adds WeakTensor support to unary ops. This wrapper infers input type according to the auto dtype conversion @@ -66,8 +66,9 @@ def weak_tensor_unary_op_wrapper(op): returns WeakTensor. """ signature = inspect.signature(op) - arg_names = iter(signature.parameters.keys()) - x_arg_name = next(arg_names) + if x_arg_name is None: + arg_names = iter(signature.parameters.keys()) + x_arg_name = next(arg_names) def wrapper(*args, **kwargs): if not ops.is_auto_dtype_conversion_enabled(): @@ -115,7 +116,6 @@ def weak_tensor_binary_op_wrapper(op): inputs. Then, both inputs are promoted to the correct promotion result dtype. If the result promotion dtype is "weak", returns WeakTensor. 
""" - signature = inspect.signature(op) arg_names = iter(signature.parameters.keys()) x_arg_name = next(arg_names) @@ -448,43 +448,66 @@ def _update_weak_tensor_patched_ops_in_dispatch_dict(patched_op): np_array_ops.var = weak_tensor_unary_op_wrapper(np_array_ops.var) np_array_ops.zeros_like = weak_tensor_unary_op_wrapper(np_array_ops.zeros_like) +# Binary ops +math_ops.add = weak_tensor_binary_op_wrapper(math_ops.add) +gen_math_ops.sub = weak_tensor_binary_op_wrapper(gen_math_ops.sub) +math_ops.multiply = weak_tensor_binary_op_wrapper(math_ops.multiply) +math_ops.multiply_no_nan = weak_tensor_binary_op_wrapper( + math_ops.multiply_no_nan +) +math_ops.matmul = weak_tensor_binary_op_wrapper(math_ops.matmul) +# In scalar_mul(scalar, x), dtype should be solely inferred from the dtype of x. +math_ops.scalar_mul = weak_tensor_unary_op_wrapper(math_ops.scalar_mul, "x") +math_ops.divide = weak_tensor_binary_op_wrapper(math_ops.divide) +math_ops.div_no_nan = weak_tensor_binary_op_wrapper(math_ops.div_no_nan) +# pylint: disable=protected-access +math_ops._truediv_python3 = weak_tensor_binary_op_wrapper( + math_ops._truediv_python3 +) +gen_math_ops.real_div = weak_tensor_binary_op_wrapper(gen_math_ops.real_div) +gen_math_ops.truncate_div = weak_tensor_binary_op_wrapper( + gen_math_ops.truncate_div +) +gen_math_ops.floor_div = weak_tensor_binary_op_wrapper(gen_math_ops.floor_div) +gen_math_ops.truncate_mod = weak_tensor_binary_op_wrapper( + gen_math_ops.truncate_mod +) +gen_math_ops.floor_mod = weak_tensor_binary_op_wrapper(gen_math_ops.floor_mod) +gen_math_ops._pow = weak_tensor_binary_op_wrapper(gen_math_ops._pow) + + # ============================================================================== # Update old op references. # ============================================================================== +math_ops.mod = gen_math_ops.floor_mod +math_ops.realdiv = gen_math_ops.real_div +math_ops.truncatediv = gen_math_ops.truncate_div +math_ops.floor_div = gen_math_ops.floor_div +math_ops.truncatemod = gen_math_ops.truncate_mod +math_ops.floormod = gen_math_ops.floor_mod + # Update Tensor dunder methods. -tensor.Tensor.__add__ = math_ops.add -tensor.Tensor.__sub__ = math_ops.sub -tensor.Tensor.__mul__ = math_ops.multiply -tensor.Tensor.__div__ = math_ops.div -tensor.Tensor.__truediv__ = math_ops.truediv -tensor.Tensor.__floordiv__ = math_ops.floordiv +# Rest of the dunder methods call the updated op because those ops have +# Python wrapper functions that call the patched op. (e.g. __add__ = +# _add_dispatch and _add_dispatch calls the updated math_ops.add). tensor.Tensor.__mod__ = gen_math_ops.floor_mod -tensor.Tensor.__pow__ = math_ops.pow -tensor.Tensor.__matmul__ = math_ops.matmul +tensor.Tensor.__rmod__ = weak_tensor_binary_op_wrapper(tensor.Tensor.__rmod__) # Set WeakTensor dunder methods. +# Tensor unary ops do not need WeakTensor support. 
weak_tensor.WeakTensor.__invert__ = math_ops.invert_ weak_tensor.WeakTensor.__neg__ = gen_math_ops.neg weak_tensor.WeakTensor.__abs__ = math_ops.abs -weak_tensor.WeakTensor.__add__ = math_ops.add -weak_tensor.WeakTensor.__sub__ = math_ops.sub -weak_tensor.WeakTensor.__mul__ = math_ops.multiply -weak_tensor.WeakTensor.__div__ = math_ops.div -weak_tensor.WeakTensor.__truediv__ = math_ops.truediv -weak_tensor.WeakTensor.__floordiv__ = math_ops.floordiv -weak_tensor.WeakTensor.__mod__ = gen_math_ops.floor_mod -weak_tensor.WeakTensor.__pow__ = math_ops.pow -weak_tensor.WeakTensor.__matmul__ = math_ops.matmul -weak_tensor.WeakTensor.__radd__ = tensor.Tensor.__radd__ -weak_tensor.WeakTensor.__rsub__ = tensor.Tensor.__rsub__ -weak_tensor.WeakTensor.__rmul__ = tensor.Tensor.__rmul__ -weak_tensor.WeakTensor.__rdiv__ = tensor.Tensor.__rdiv__ -weak_tensor.WeakTensor.__rtruediv__ = tensor.Tensor.__rtruediv__ -weak_tensor.WeakTensor.__rfloordiv__ = tensor.Tensor.__rfloordiv__ -weak_tensor.WeakTensor.__rmod__ = tensor.Tensor.__rmod__ -weak_tensor.WeakTensor.__rpow__ = tensor.Tensor.__rpow__ -weak_tensor.WeakTensor.__rmatmul__ = tensor.Tensor.__rmatmul__ + +# Inherit rest of the dunder methods from Tensor. +unary_dunder_methods = ["__invert__", "__neg__", "__abs__"] +for operator in tensor.Tensor.OVERLOADABLE_OPERATORS: + if operator in unary_dunder_methods: + continue + tensor_oper = getattr(tensor.Tensor, operator) + setattr(weak_tensor.WeakTensor, operator, tensor_oper) # Add/Update NumPy methods in Tensor and WeakTensor. np_math_ops.enable_numpy_methods_on_tensor() -np_math_ops._enable_numpy_methods(weak_tensor.WeakTensor) # pylint: disable=protected-access +np_math_ops._enable_numpy_methods(weak_tensor.WeakTensor) +# pylint: enable=protected-access diff --git a/tensorflow/python/ops/weak_tensor_ops_test.py b/tensorflow/python/ops/weak_tensor_ops_test.py index fc57ba70c2bae2..f1a50170ab8ada 100644 --- a/tensorflow/python/ops/weak_tensor_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_ops_test.py @@ -20,6 +20,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type +from tensorflow.python.framework import flexible_dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util @@ -28,6 +29,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_bitwise_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import weak_tensor_ops @@ -40,9 +42,10 @@ from tensorflow.python.util import dispatch +DtypeConversionTestEnv = weak_tensor_test_util.DtypeConversionTestEnv _get_weak_tensor = weak_tensor_test_util.get_weak_tensor _convert_to_input_type = weak_tensor_test_util.convert_to_input_type - +_DTYPE_PROMO_RES = flexible_dtypes._BINARY_DTYPE_RES_HALF _TF_UNARY_APIS = weak_tensor_ops._TF_UNARY_APIS _TF_UNARY_APIS_SPECIFIC_DTYPE = [ @@ -75,6 +78,8 @@ array_ops.extract_image_patches_v2, array_ops.space_to_depth, array_ops.space_to_depth_v2, + math_ops.scalar_mul, + math_ops.scalar_mul_v2, ] _TF_UNARY_APIS_WITH_INT_INPUT = [ gen_bitwise_ops.invert, @@ -98,12 +103,24 @@ array_ops.transpose_v2, ] +all_dtype_promos_list = [] +safe_mode_unallowed_promos_list = [] +for key in _DTYPE_PROMO_RES: + if key[0] == dtypes.bool: + continue + for k, v in 
_DTYPE_PROMO_RES[key].items(): + if v[1] == ops.PromoMode.ALL: + safe_mode_unallowed_promos_list.append((key, k)) + all_dtype_promos_list.append((key, k, v[0])) + class MyTensor(extension_type.ExtensionType): value: tensor.Tensor -class WeakTensorOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): +class WeakTensorUnaryOpsTest( + test_util.TensorFlowTestCase, parameterized.TestCase +): # Test unary ops with one input. @parameterized.named_parameters( @@ -112,7 +129,7 @@ class WeakTensorOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): ) def test_unary_ops_return_weak_tensor(self, unary_api): weak_tensor_input, python_input, tensor_input, numpy_input = ( - _get_test_input(unary_api) + get_test_input_for_unary_op(unary_api) ) # Check that WeakTensor input outputs a WeakTensor. @@ -138,7 +155,7 @@ def test_unary_ops_return_weak_tensor(self, unary_api): @parameterized.parameters( ("WeakTensor", dtypes.float32, WeakTensor), ("Python", dtypes.float32, WeakTensor), - ("NumPy", np.float32, tensor.Tensor), + ("NumPy", dtypes.float32, tensor.Tensor), ("NumPy", None, tensor.Tensor), ("Tensor", dtypes.float32, tensor.Tensor), ) @@ -213,7 +230,7 @@ def test_unary_ops_return_normal_tensor(self, unary_api_specific_dtype): @parameterized.parameters( ("WeakTensor", dtypes.float32, WeakTensor), ("Python", None, WeakTensor), - ("NumPy", np.float32, tensor.Tensor), + ("NumPy", dtypes.float32, tensor.Tensor), ("NumPy", None, tensor.Tensor), ("Tensor", dtypes.float32, tensor.Tensor), ) @@ -250,7 +267,7 @@ def test_elementwise_unary_ops_optional_dtype( ("Python", None, None, WeakTensor), ("Python", None, dtypes.int32, tensor.Tensor), ("NumPy", None, None, tensor.Tensor), - ("NumPy", None, np.int32, tensor.Tensor), + ("NumPy", None, dtypes.int32, tensor.Tensor), ("Tensor", dtypes.float32, None, tensor.Tensor), ("Tensor", dtypes.float32, dtypes.int32, tensor.Tensor), ) @@ -365,9 +382,712 @@ def testNumpyMethodsOnWeakTensor(self, np_method, result_type, *args): self.assertAllClose(wt_np_result, t_np_result) +@parameterized.parameters(all_dtype_promos_list) +class WeakTensorBinaryOpsTest( + test_util.TensorFlowTestCase, parameterized.TestCase +): + + def match_expected(self, actual, expected_val, expected_dtype): + dtype, weak = expected_dtype + expected_type = WeakTensor if weak else tensor.Tensor + self.assertIsInstance(actual, expected_type) + self.assertEqual(actual.dtype, dtype) + self.assertAllEqual(actual, expected_val) + + def test_weak_tensor_add(self, a_dtype, b_dtype, expected_dtype): + def run_test_add(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + expected_val = constant_op.constant( + a, expected_dtype[0] + ) + constant_op.constant(b, expected_dtype[0]) + for x, y in zip(a_list, b_list): + self.match_expected(math_ops.add(x, y), expected_val, expected_dtype) + self.match_expected(math_ops.add(y, x), expected_val, expected_dtype) + if at_least_one_tensor_type(x, y): + self.match_expected(x + y, expected_val, expected_dtype) + self.match_expected(y + x, expected_val, expected_dtype) + + # Limit testing values to positive numbers inputs to account for + # both unsigned and signed input types. 
+ run_test_add(a=2, b=4) + run_test_add(a=100, b=100) + run_test_add(a=10, b=41) + + def test_weak_tensor_sub(self, a_dtype, b_dtype, expected_dtype): + def run_test_sub(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = a_tensor - b_tensor + expected_val_reverse = b_tensor - a_tensor + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.subtract(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.subtract(y, x), expected_val_reverse, expected_dtype + ) + if at_least_one_tensor_type(x, y): + self.match_expected(x - y, expected_val, expected_dtype) + self.match_expected(y - x, expected_val_reverse, expected_dtype) + + run_test_sub(a=4, b=2) + run_test_sub(a=41, b=0) + run_test_sub(a=100, b=50) + + def test_weak_tensor_mul(self, a_dtype, b_dtype, expected_dtype): + def run_test_mul(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + expected_val = constant_op.constant( + a, expected_dtype[0] + ) * constant_op.constant(b, expected_dtype[0]) + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.multiply(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.multiply(y, x), expected_val, expected_dtype + ) + if at_least_one_tensor_type(a, b): + self.match_expected(x * y, expected_val, expected_dtype) + self.match_expected(y * x, expected_val, expected_dtype) + + run_test_mul(a=4, b=2) + run_test_mul(a=41, b=10) + run_test_mul(a=10, b=5) + + def test_weak_tensor_pow(self, a_dtype, b_dtype, expected_dtype): + def run_test_pow(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("pow", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = a_tensor**b_tensor + reverse_expected_val = b_tensor**a_tensor + for x, y in zip(a_list, b_list): + self.match_expected(math_ops.pow(x, y), expected_val, expected_dtype) + self.match_expected( + math_ops.pow(y, x), reverse_expected_val, expected_dtype + ) + if at_least_one_tensor_type(x, y): + self.match_expected(x**y, expected_val, expected_dtype) + self.match_expected(y**x, reverse_expected_val, expected_dtype) + + run_test_pow(a=4, b=2) + run_test_pow(a=41, b=10) + run_test_pow(a=2, b=6) + + def test_weak_tensor_mod(self, a_dtype, b_dtype, expected_dtype): + def run_test_mod(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("mod", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = a_tensor % b_tensor + reverse_expected_val = b_tensor % a_tensor + for x, y in zip(a_list, b_list): + self.match_expected(math_ops.mod(x, y), expected_val, expected_dtype) + self.match_expected( + math_ops.mod(y, x), reverse_expected_val, expected_dtype + ) + # math_ops.mod and gen_math_ops.floor_mod are used interchangeably. 
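+      # (weak_tensor_ops.py rebinds math_ops.mod to the patched
+      # gen_math_ops.floor_mod, so both entry points go through the same
+      # WeakTensor-aware wrapper and must agree on value and result dtype.)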
+ self.match_expected( + gen_math_ops.floor_mod(x, y), expected_val, expected_dtype + ) + self.match_expected( + gen_math_ops.floor_mod(y, x), reverse_expected_val, expected_dtype + ) + if at_least_one_tensor_type(x, y): + self.match_expected(x % y, expected_val, expected_dtype) + self.match_expected(y % x, reverse_expected_val, expected_dtype) + + run_test_mod(a=4, b=2) + run_test_mod(a=41, b=124) + run_test_mod(a=2, b=6) + + def test_weak_tensor_floor_div(self, a_dtype, b_dtype, expected_dtype): + def run_test_floor_div(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("floor_div", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = a_tensor // b_tensor + reverse_expected_val = b_tensor // a_tensor + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.floordiv(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.floordiv(y, x), reverse_expected_val, expected_dtype + ) + # math_ops.floordiv and math_ops.floor_div are used interchangeably. + self.match_expected( + math_ops.floor_div(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.floor_div(y, x), reverse_expected_val, expected_dtype + ) + if at_least_one_tensor_type(x, y): + self.match_expected(x // y, expected_val, expected_dtype) + self.match_expected(y // x, reverse_expected_val, expected_dtype) + + run_test_floor_div(a=124, b=123) + run_test_floor_div(a=41, b=20) + run_test_floor_div(a=2, b=6) + + def test_weak_tensor_real_div(self, a_dtype, b_dtype, expected_dtype): + def run_test_real_div(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("real_div", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.real_div(a_tensor, b_tensor) + reverse_expected_val = math_ops.real_div(b_tensor, a_tensor) + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.realdiv(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.realdiv(y, x), reverse_expected_val, expected_dtype + ) + # math_ops.realdiv and gen_math_ops.real_div are used interchangeably. + self.match_expected( + gen_math_ops.real_div(x, y), expected_val, expected_dtype + ) + self.match_expected( + gen_math_ops.real_div(y, x), reverse_expected_val, expected_dtype + ) + + run_test_real_div(a=124, b=123) + run_test_real_div(a=41, b=20) + run_test_real_div(a=2, b=6) + + def test_weak_tensor_truncate_div(self, a_dtype, b_dtype, expected_dtype): + def run_test_truncate_div(a, b): + # Skip if provided dtype is not a valid input dtype for the op. 
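+      # (Validity here means the dtype appears in the per-op eager-kernel
+      # dtype lists kept in output_dtype_supported_in_op, defined later in
+      # this file.)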
+ if not output_dtype_supported_in_op("truncate_div", expected_dtype[0]): + return + + a, b = maybe_to_positive_input(a, b, a_dtype, b_dtype, expected_dtype) + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.truncatediv(a_tensor, b_tensor) + reverse_expected_val = math_ops.truncatediv(b_tensor, a_tensor) + + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.truncatediv(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.truncatediv(y, x), reverse_expected_val, expected_dtype + ) + # math_ops.truncatediv and gen_math_ops.truncate_div are used + # interchangeably. + self.match_expected( + gen_math_ops.truncate_div(x, y), expected_val, expected_dtype + ) + self.match_expected( + gen_math_ops.truncate_div(y, x), + reverse_expected_val, + expected_dtype, + ) + + run_test_truncate_div(a=124, b=123) + run_test_truncate_div(a=41, b=20) + run_test_truncate_div(a=2, b=6) + run_test_truncate_div(a=-7, b=5) + run_test_truncate_div(a=1, b=-2) + run_test_truncate_div(a=-100, b=-50) + + def test_weak_tensor_truncate_mod(self, a_dtype, b_dtype, expected_dtype): + def run_test_truncate_mod(a, b): + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("truncate_mod", expected_dtype[0]): + return + + a, b = maybe_to_positive_input(a, b, a_dtype, b_dtype, expected_dtype) + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.truncatemod(a_tensor, b_tensor) + reverse_expected_val = math_ops.truncatemod(b_tensor, a_tensor) + + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.truncatemod(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.truncatemod(y, x), reverse_expected_val, expected_dtype + ) + # math_ops.truncatemod and gen_math_ops.truncate_mod are used + # interchangeably. + self.match_expected( + gen_math_ops.truncate_mod(x, y), expected_val, expected_dtype + ) + self.match_expected( + gen_math_ops.truncate_mod(y, x), + reverse_expected_val, + expected_dtype, + ) + + run_test_truncate_mod(a=124, b=123) + run_test_truncate_mod(a=41, b=20) + run_test_truncate_mod(a=2, b=6) + run_test_truncate_mod(a=-7, b=5) + run_test_truncate_mod(a=1, b=-1) + run_test_truncate_mod(a=-100, b=-50) + + def test_weak_tensor_scalar_mul(self, a_dtype, b_dtype, expected_dtype): + def run_test_scalar_mul(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + # Expected dtype = second arg's dtype. + _ = expected_dtype + if not a_dtype[0].is_compatible_with(b_dtype[0]): + return + expected_val = np.multiply(a, b) + for x, y in zip(a_list, b_list): + self.match_expected(math_ops.scalar_mul(x, y), expected_val, b_dtype) + self.match_expected(math_ops.scalar_mul(y, x), expected_val, a_dtype) + + run_test_scalar_mul(a=4, b=1) + run_test_scalar_mul(a=41, b=2) + run_test_scalar_mul(a=2, b=0) + + def test_weak_tensor_mat_mul(self, a_dtype, b_dtype, expected_dtype): + def run_test_mat_mul(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. 
+ if not output_dtype_supported_in_op("matmul", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.matmul(a_tensor, b_tensor) + expected_val_reverse = math_ops.matmul(b_tensor, a_tensor) + for x, y in zip(a_list, b_list): + self.match_expected(math_ops.matmul(x, y), expected_val, expected_dtype) + self.match_expected( + math_ops.matmul(y, x), expected_val_reverse, expected_dtype + ) + + run_test_mat_mul(a=[[2, 1], [3, 4]], b=[[1, 2], [3, 4]]) + run_test_mat_mul(a=[[[3]]], b=[[[2]]]) + + def test_weak_tensor_truediv(self, a_dtype, b_dtype, expected_dtype): + def run_test_truediv(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = a_tensor / b_tensor + reverse_expected_val = b_tensor / a_tensor + for x, y in zip(a_list, b_list): + # Truediv has a dtype conversion orthagonal to our change. Therefore, + # we compare our result dtype to Tensor truediv. + expected_result_dtype = expected_val.dtype + self.match_expected( + math_ops.truediv(x, y), + expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + self.match_expected( + math_ops.truediv(y, x), + reverse_expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + # truediv, divide, and divide dunder method all use Python 3 division + # semantics. + self.match_expected( + math_ops.divide(x, y), + expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + self.match_expected( + math_ops.divide(y, x), + reverse_expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + if at_least_one_tensor_type(x, y): + self.match_expected( + x / y, expected_val, (expected_result_dtype, expected_dtype[1]) + ) + self.match_expected( + y / x, + reverse_expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + + run_test_truediv(a=4, b=2) + run_test_truediv(a=41, b=3) + run_test_truediv(a=2, b=6) + + def test_weak_tensor_div_no_nan(self, a_dtype, b_dtype, expected_dtype): + def run_test_div_no_nan(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.div_no_nan(a_tensor, b_tensor) + reverse_expected_val = math_ops.div_no_nan(b_tensor, a_tensor) + # The behavior of div_no_nan is same as truediv in most cases, except + # for when it divides by nan or 0. + expected_result_dtype = expected_val.dtype + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.div_no_nan(x, y), + expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + self.match_expected( + math_ops.div_no_nan(y, x), + reverse_expected_val, + (expected_result_dtype, expected_dtype[1]), + ) + + run_test_div_no_nan(a=4, b=2) + run_test_div_no_nan(a=41, b=40) + run_test_div_no_nan(a=2, b=6) + + # Test div_no_nan(x, 0) = 0 even if x is NaN or Inf. 
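+    # (np.NaN and np.Inf are plain Python floats, so both operands are weak
+    # and the promotion result is the weak float32 default asserted below.)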
+ x = np.NaN + y = 0 + self.match_expected(math_ops.div_no_nan(x, y), 0, (dtypes.float32, True)) + + x = np.Inf + self.match_expected(math_ops.div_no_nan(x, y), 0, (dtypes.float32, True)) + + def test_weak_tensor_multiply_no_nan(self, a_dtype, b_dtype, expected_dtype): + def run_test_multiply_no_nan(a, b): + a_list = get_test_input_for_binary_op(a, a_dtype) + b_list = get_test_input_for_binary_op(b, b_dtype) + + # Skip if provided dtype is not a valid input dtype for the op. + if not output_dtype_supported_in_op("multiply_no_nan", expected_dtype[0]): + return + + a_tensor = constant_op.constant(a, expected_dtype[0]) + b_tensor = constant_op.constant(b, expected_dtype[0]) + expected_val = math_ops.multiply_no_nan(a_tensor, b_tensor) + for x, y in zip(a_list, b_list): + self.match_expected( + math_ops.multiply_no_nan(x, y), expected_val, expected_dtype + ) + self.match_expected( + math_ops.multiply_no_nan(y, x), expected_val, expected_dtype + ) + + run_test_multiply_no_nan(a=4, b=2) + run_test_multiply_no_nan(a=41, b=10) + run_test_multiply_no_nan(a=2, b=6) + + # Test multiply_no_nan(x, 0) = 0 even if x is NaN or Inf. + x = np.NaN + y = 0 + self.match_expected( + math_ops.multiply_no_nan(x, y), 0, (dtypes.float32, True) + ) + + x = np.Inf + self.match_expected( + math_ops.multiply_no_nan(x, y), 0, (dtypes.float32, True) + ) + + +@parameterized.parameters(safe_mode_unallowed_promos_list) +class WeakTensorBinaryOpsTestSafeMode( + test_util.TensorFlowTestCase, parameterized.TestCase +): + + def test_weak_tensor_add(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.add(x, y) + with self.assertRaises(TypeError): + _ = math_ops.add(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x + y + with self.assertRaises(TypeError): + _ = y + x + + def test_weak_tensor_sub(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.subtract(x, y) + with self.assertRaises(TypeError): + _ = math_ops.subtract(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x - y + with self.assertRaises(TypeError): + _ = y - x + + def test_weak_tensor_mul(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.multiply(x, y) + with self.assertRaises(TypeError): + _ = math_ops.multiply(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x * y + with self.assertRaises(TypeError): + _ = y * x + + def test_weak_tensor_pow(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.pow(x, y) + with self.assertRaises(TypeError): + _ = math_ops.pow(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x**y + with self.assertRaises(TypeError): + _ = y**x + + def test_weak_tensor_mod(self, a_dtype, b_dtype): + with 
DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.mod(x, y) + with self.assertRaises(TypeError): + _ = math_ops.mod(y, x) + with self.assertRaises(TypeError): + _ = gen_math_ops.floor_mod(x, y) + with self.assertRaises(TypeError): + _ = gen_math_ops.floor_mod(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x % y + with self.assertRaises(TypeError): + _ = y % x + + def test_weak_tensor_floor_div(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.floordiv(x, y) + with self.assertRaises(TypeError): + _ = math_ops.floordiv(y, x) + with self.assertRaises(TypeError): + _ = gen_math_ops.floor_div(x, y) + with self.assertRaises(TypeError): + _ = gen_math_ops.floor_div(y, x) + if at_least_one_tensor_type(x, y): + with self.assertRaises(TypeError): + _ = x // y + with self.assertRaises(TypeError): + _ = y // x + + def test_weak_tensor_real_div(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.realdiv(x, y) + with self.assertRaises(TypeError): + _ = math_ops.realdiv(y, x) + with self.assertRaises(TypeError): + _ = gen_math_ops.real_div(x, y) + with self.assertRaises(TypeError): + _ = gen_math_ops.real_div(y, x) + + def test_weak_tensor_truncate_mod(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.truncatemod(x, y) + with self.assertRaises(TypeError): + _ = math_ops.truncatemod(y, x) + with self.assertRaises(TypeError): + _ = gen_math_ops.truncate_mod(x, y) + with self.assertRaises(TypeError): + _ = gen_math_ops.truncate_mod(y, x) + + def test_weak_tensor_truncate_div(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op(1, a_dtype) + b_list = get_test_input_for_binary_op(1, b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.truncatediv(x, y) + with self.assertRaises(TypeError): + _ = math_ops.truncatediv(y, x) + with self.assertRaises(TypeError): + _ = gen_math_ops.truncate_div(x, y) + with self.assertRaises(TypeError): + _ = gen_math_ops.truncate_div(y, x) + + def test_weak_tensor_mat_mul(self, a_dtype, b_dtype): + with DtypeConversionTestEnv("safe"): + a_list = get_test_input_for_binary_op([[1]], a_dtype) + b_list = get_test_input_for_binary_op([[1]], b_dtype) + for x, y in zip(a_list, b_list): + with self.assertRaises(TypeError): + _ = math_ops.matmul(x, y) + with self.assertRaises(TypeError): + _ = math_ops.matmul(y, x) + with self.assertRaises(TypeError): + _ = math_ops.matmul(x, y) + with self.assertRaises(TypeError): + _ = math_ops.matmul(y, x) + + +def get_test_input_for_binary_op(val, dtype): + """Returns a list containing all the possible inputs with a given dtype.""" + python_inferred_types = { + (dtypes.int32, True): 1, + (dtypes.float32, True): 1.0, + (dtypes.complex128, True): 1.0j, 
+ } + dtype, weak = dtype + inputs = [] + if weak: + # WeakTensor and Python input types. + inputs.append(_convert_to_input_type(val, "WeakTensor", dtype)) + if dtype in python_inferred_types: + # There are only 3 possible Python default types : int, float, complex. + inputs.append(val * python_inferred_types[dtype]) + else: + # Tensor and NumPy input types. + inputs.append(_convert_to_input_type(val, "Tensor", dtype)) + inputs.append(_convert_to_input_type(val, "NumPy", dtype)) + return inputs + + +def at_least_one_tensor_type(a, b): + """Returns True if at least one of the inputs is a Tensor/WeakTensor.""" + if isinstance(a, tensor.Tensor) or isinstance(a, WeakTensor): + return True + if isinstance(b, tensor.Tensor) or isinstance(b, WeakTensor): + return True + return False + + +def maybe_to_positive_input(a, b, a_dtype, b_dtype, expected_dtype): + """Converts inputs to positive inputs if the provided dtypes are unsigned.""" + unsigned_types = [dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64] + if a < 0 and ( + a_dtype[0] in unsigned_types or expected_dtype[0] in unsigned_types + ): + a = a * (-1) + if b < 0 and ( + b_dtype[0] in unsigned_types or expected_dtype[0] in unsigned_types + ): + b = b * (-1) + return a, b + + +def output_dtype_supported_in_op(op_name, input_dtype): + real_dtypes = [ + dtypes.int8, + dtypes.int16, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + dtypes.bfloat16, + dtypes.half, + dtypes.float32, + dtypes.float64, + ] + # Valid dtypes for the given op in Eager Mode. + valid_dtypes_in_eager = { + "pow": [ + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.int64, + dtypes.complex64, + dtypes.complex128, + ], + "mod": real_dtypes, + "floor_div": real_dtypes, + "real_div": [ + dtypes.bfloat16, + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.complex64, + dtypes.complex128, + ], + "truncate_div": real_dtypes, + "truncate_mod": [ + dtypes.int32, + dtypes.int64, + dtypes.float32, + dtypes.float64, + ], + "matmul": [ + dtypes.bfloat16, + dtypes.float16, + dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.int64, + dtypes.complex64, + dtypes.complex128, + ], + "multiply_no_nan": [dtypes.float32, dtypes.float64], + } + return input_dtype in valid_dtypes_in_eager[op_name] + + # TODO(b/289333658): Add tf.constant(x) with no dtype arg as a "weak" input # after adding WeakTensor construction logic to tf.constant. 
-def _get_test_input(op): +def get_test_input_for_unary_op(op): if op in _TF_UNARY_APIS_WITH_INT_INPUT: return ( _get_weak_tensor(5, dtypes.int32), diff --git a/tensorflow/python/ops/weak_tensor_test_util.py b/tensorflow/python/ops/weak_tensor_test_util.py index aa117def50c086..2eb399cbda0c47 100644 --- a/tensorflow/python/ops/weak_tensor_test_util.py +++ b/tensorflow/python/ops/weak_tensor_test_util.py @@ -17,6 +17,7 @@ import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.weak_tensor import WeakTensor @@ -27,6 +28,7 @@ def convert_to_input_type(base_input, input_type, dtype=None): elif input_type == "Tensor": return constant_op.constant(base_input, dtype=dtype) elif input_type == "NumPy": + dtype = dtype.as_numpy_dtype if isinstance(dtype, dtypes.DType) else dtype return np.array(base_input, dtype=dtype) elif input_type == "Python": return base_input From ec5bbc096de51edd47bbf43cb7776d0716f5a32a Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 25 Jul 2023 17:27:48 -0700 Subject: [PATCH 161/410] Internal code clean up PiperOrigin-RevId: 551044622 --- tensorflow/dtensor/mlir/BUILD | 29 -- .../dtensor_allreduce_combine_optimization.cc | 42 +-- .../dtensor_allreduce_scatter_optimization.cc | 1 - .../mlir/dtensor_collective_type_lowering.cc | 1 - tensorflow/dtensor/mlir/group_assignment.cc | 259 ------------------ tensorflow/dtensor/mlir/group_assignment.h | 188 ------------- .../dtensor/mlir/group_assignment_test.cc | 218 --------------- tensorflow/dtensor/mlir/utils/BUILD | 9 +- .../dtensor/mlir/utils/collective_lowering.cc | 13 - .../utils/dtensor_mlir_passes_internal.cc | 15 +- .../dtensor/mlir/utils/update_tpu_metadata.cc | 29 +- tensorflow/opensource_only.files | 1 - 12 files changed, 25 insertions(+), 780 deletions(-) delete mode 100644 tensorflow/dtensor/mlir/group_assignment.cc delete mode 100644 tensorflow/dtensor/mlir/group_assignment.h delete mode 100644 tensorflow/dtensor/mlir/group_assignment_test.cc diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index e543e419cde8a5..d627c163409b62 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -221,7 +221,6 @@ cc_library( ":dtensor_location", ":dtensor_passes_inc_gen", ":dtensor_send_recv", - ":group_assignment", ":layout_parsing", ":op_utils", ":shape_utils", @@ -296,34 +295,6 @@ cc_library( alwayslink = True, ) -cc_library( - name = "group_assignment", - srcs = ["group_assignment.cc"], - hdrs = ["group_assignment.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/dtensor/cc:dstatus", - "@com_google_absl//absl/container:flat_hash_map", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - -tf_cc_test( - name = "group_assignment_test", - srcs = ["group_assignment_test.cc"], - deps = [ - ":group_assignment", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/dtensor/cc:dstatus", - "@com_google_absl//absl/container:flat_hash_map", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - cc_library( name = "layout_parsing", srcs = [ diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index 11099801b45ae7..038df0a00eecd8 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ 
b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -41,7 +41,6 @@ limitations under the License. #include "tensorflow/dtensor/cc/dtensor_utils.h" #include "tensorflow/dtensor/mlir/dtensor_location.h" #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" -#include "tensorflow/dtensor/mlir/group_assignment.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" @@ -313,46 +312,13 @@ std::string DrawAllReduceDependencies( // clang-format on mlir::LogicalResult CombineAllReduceOps( mlir::tf_device::ClusterOp cluster, - const std::vector& all_reduces) { - // Drop within-slice all-reduces. - std::vector cross_slice_all_reduces; - for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) { - mlir::DenseIntElementsAttr group_assignment_attr; - if (!matchPattern(all_reduce.getGroupAssignment(), - m_Constant(&group_assignment_attr))) { - return all_reduce.emitOpError("group_assignment should be a constant"); - } - // LINT.IfChange - // TODO(ishark): Confirm the right check for GPUs. - int num_slices = NumClients(); - int slice_size = kTpuDonutSize; - if (group_assignment_attr.getNumElements() < kTpuDonutSize) { - DCHECK_EQ(num_slices, 1) << "Num slices expected to be equal to 1."; - slice_size = group_assignment_attr.getNumElements(); - } - StatusOr group_assignment = GroupAssignment::FromMLIR( - group_assignment_attr, - GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( - num_slices, slice_size)); - // LINT.ThenChange(//tensorflow/dtensor/mlir/utils/collective_lowering_google.cc) - if (!group_assignment.ok()) { - return all_reduce.emitOpError( - llvm::formatv("Failed to create a GroupAssignment due to {0}", - group_assignment.status().message())); - } - // Unit tests have only one slice. Always combine all all-reduces in them. - if (group_assignment->num_slices() == 1 || - !group_assignment->IsWithinSlices()) { - cross_slice_all_reduces.push_back(all_reduce); - } - } - + std::vector& all_reduces) { // A single op has nothing to combine with. - int num_all_reduces = cross_slice_all_reduces.size(); + int num_all_reduces = all_reduces.size(); if (num_all_reduces <= 1) return mlir::success(); // Move all-reduces in the same group together and combine them. - auto& all_reduce_group = cross_slice_all_reduces; + auto& all_reduce_group = all_reduces; mlir::TF::DTensorAllReduceOp final_all_reduce = all_reduce_group[num_all_reduces - 1]; @@ -771,7 +737,7 @@ struct DTensorAllReduceCombineOptimization // Within the block, use the group's actual sorting. return lhs[0]->isBeforeInBlock(rhs[0]); }); - for (const auto& reduce_group : all_reduce_groups) { + for (auto& reduce_group : all_reduce_groups) { if (reduce_group.size() > 1) { VLOG(4) << "Combining following reduce ops into one: ------------"; for (auto reduce_op : reduce_group) { diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_scatter_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_scatter_optimization.cc index dc9cb9347de1fc..2eb8d4da18890d 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_scatter_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_scatter_optimization.cc @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/dtensor/mlir/collectives_common.h" #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" -#include "tensorflow/dtensor/mlir/group_assignment.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" diff --git a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc index b3aa8a70d9fe4d..d49c397237414d 100644 --- a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc +++ b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc @@ -46,7 +46,6 @@ limitations under the License. #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h" #include "tensorflow/dtensor/mlir/dtensor_location.h" -#include "tensorflow/dtensor/mlir/group_assignment.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" diff --git a/tensorflow/dtensor/mlir/group_assignment.cc b/tensorflow/dtensor/mlir/group_assignment.cc deleted file mode 100644 index 8f90a8b4786b7c..00000000000000 --- a/tensorflow/dtensor/mlir/group_assignment.cc +++ /dev/null @@ -1,259 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/dtensor/mlir/group_assignment.h" - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/str_util.h" -#include "tensorflow/core/platform/strcat.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/dtensor/cc/dstatus.h" - -namespace tensorflow { -namespace dtensor { - -GroupAssignment::ReplicaToDeviceMap -GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(int num_slices, - int slice_size) { - absl::flat_hash_map map; - for (int i = 0; i < num_slices; ++i) { - for (int j = 0; j < slice_size; ++j) { - map[ReplicaId{i * slice_size + j}] = DeviceId{i, j}; - } - } - return ReplicaToDeviceMap(std::move(map)); -} - -GroupAssignment::ReplicaToDeviceMap::ReplicaToDeviceMap( - absl::flat_hash_map map) - : map_(std::move(map)) { - std::set slice_ids; - for (const auto& entry : map_) { - slice_ids.insert(entry.second.slice_id); - } - CHECK_GT(slice_ids.size(), 0); // Crash OK - CHECK_EQ(map_.size() % slice_ids.size(), 0); // Crash OK - num_slices_ = slice_ids.size(); -} - -GroupAssignment::ReplicaGroups::ReplicaGroups( - std::vector> replica_ids) - : replica_ids_(std::move(replica_ids)) { - int n = replica_ids_.size(); - CHECK_GT(n, 0); // Crash OK - int g = replica_ids_.front().size(); - CHECK_GT(g, 0); // Crash OK - std::set seen_replica_ids; - for (std::vector& group : replica_ids_) { - CHECK_EQ(group.size(), g); // Crash OK - for (int replica_id : group) { - CHECK_GE(replica_id, 0); // Crash OK - bool inserted = seen_replica_ids.insert(replica_id).second; - CHECK(inserted); // Crash OK - } - } -} - -mlir::DenseIntElementsAttr GroupAssignment::ReplicaGroups::ToMLIR( - mlir::MLIRContext& context) const { - auto shaped_type = mlir::RankedTensorType::get( - {num_groups(), group_size()}, mlir::IntegerType::get(&context, 32)); - - llvm::SmallVector flat_replica_ids; - flat_replica_ids.reserve(num_replica_ids()); - for (const std::vector& group : replica_ids()) { - flat_replica_ids.insert(flat_replica_ids.end(), group.begin(), group.end()); - } - - return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids); -} - -std::string GroupAssignment::ReplicaGroups::ToString() const { - return strings::StrCat( - "[", - str_util::Join(replica_ids(), ", ", - [](std::string* str, const std::vector& group) { - strings::StrAppend(str, "[", str_util::Join(group, ", "), - "]"); - }), - "]"); -} - -StatusOr GroupAssignment::FromMLIR( - const mlir::DenseIntElementsAttr& group_assignment_attr, - ReplicaToDeviceMap replica_to_device_map) { - mlir::ShapedType shaped_type = group_assignment_attr.getType(); - if (!shaped_type.hasRank()) { - return errors::InvalidArgument("group_assignment_attr must have a rank"); - } - if (shaped_type.getRank() != 2) { - return errors::InvalidArgument( - "group_assignment_attr must have a rank of 2, got ", - shaped_type.getRank()); - } - llvm::ArrayRef shape = shaped_type.getShape(); - int num_groups = shape[0]; - if (num_groups <= 0) { - return errors::InvalidArgument( - "group_assignment_attr must have at least 1 group, got ", 
num_groups); - } - int group_size = shape[1]; - if (group_size <= 0) { - return errors::InvalidArgument( - "group_assignment_attr must have non-empty groups, got ", group_size, - " replica IDs per group"); - } - int num_replica_ids = num_groups * group_size; - if (num_replica_ids != replica_to_device_map.num_cores()) { - return errors::InvalidArgument("group_assignment_attr must have ", - replica_to_device_map.num_cores(), - " replica IDs, got ", num_replica_ids); - } - - // Translate the flat group assignment to a 2D array. - std::vector> replica_ids; - replica_ids.resize(num_groups, std::vector(group_size)); - std::set seen_replica_ids; - if (group_assignment_attr.getNumElements() != num_replica_ids) { - return errors::InvalidArgument( - "group_assignments_attr num elements was not equal to the number of " - "replica ids."); - } - for (const auto& it : - llvm::enumerate(group_assignment_attr.getValues())) { - int index = it.index(); - int replica_id = it.value().getSExtValue(); - - // If all replica IDs are within this range and distinct, they must be a - // permutation of [0, ..., num_replica_ids). - if (replica_id < 0 || replica_id >= num_replica_ids) { - return errors::InvalidArgument("Out of range replica ID: ", replica_id); - } - if (!seen_replica_ids.insert(replica_id).second) { - return errors::InvalidArgument( - "All replica IDs in group_assigment must be distinct, seeing ", - replica_id, " more than once"); - } - - replica_ids[index / group_size][index % group_size] = replica_id; - } - - GroupAssignment group_assignment( - /*global=*/ReplicaGroups(std::move(replica_ids)), - std::move(replica_to_device_map)); - TF_RETURN_IF_ERROR(group_assignment.GlobalToSlices()); - return group_assignment; -} - -std::string GroupAssignment::ToString() const { - return strings::StrCat( - "GroupAssignment global: ", global_.ToString(), "; hosts: ", - hosts_.empty() - ? "" - : str_util::Join(hosts_, ", ", - [](std::string* str, const ReplicaGroups& groups) { - strings::StrAppend(str, groups.ToString()); - }), - "; slices: ", - slices_.empty() - ? "" - : str_util::Join(slices_, ", ", - [](std::string* str, const ReplicaGroups& groups) { - strings::StrAppend(str, groups.ToString()); - })); -} - -bool GroupAssignment::IsWithinSlices() const { - // This function returns true iff no group in the global view gets split in - // `GlobalToSlices`, i.e., the total group count remains the same. - int total_num_groups = 0; - for (int i = 0; i < num_slices(); i++) { - total_num_groups += num_groups(i).value(); - } - if (total_num_groups != num_groups()) return false; - return total_num_groups == num_groups(); -} - -Status GroupAssignment::GlobalToSlices() { - VLOG(2) << "Original group assignment: " << ToString(); - - int num_slices = replica_to_device_map_.num_slices(); - if (num_slices == 0) { - return errors::InvalidArgument("Unexpectedly empty replica_to_device_map."); - } - - // For each replica group in global replica groups, divide its replicas based - // on which slices they come from. Then, for each slice, collect subgroups - // from every such division and form a new ReplicaGroup for that slice. 
- std::vector>> replica_groups_per_host; - std::vector>> replica_groups_per_slice; - replica_groups_per_host.resize(num_slices, {}); - replica_groups_per_slice.resize(num_slices, {}); - - for (const std::vector& replica_group : replica_ids()) { - std::vector> replica_group_divided_by_host; - replica_group_divided_by_host.resize(num_slices, {}); - std::vector> replica_group_divided_by_slice; - replica_group_divided_by_slice.resize(num_slices, {}); - - for (int replica_id : replica_group) { - // TODO(b/183426911): Use DeviceId::core_id in ReplicaGroup directly for - // now. Integrate with device assignment with proper typing. - DeviceId device_id = replica_to_device_map_.device_id(replica_id); - replica_group_divided_by_host[device_id.slice_id].push_back(replica_id); - replica_group_divided_by_slice[device_id.slice_id].push_back( - device_id.core_id); - } - - for (int i = 0; i < num_slices; ++i) { - if (!replica_group_divided_by_host[i].empty()) { - // Host meshes have the same global device and replica IDs as TPU - // meshes. Let the first replica in every group do a host collective. - replica_groups_per_host[i].push_back( - std::vector(1, replica_group_divided_by_host[i].front())); - } - if (!replica_group_divided_by_slice[i].empty()) { - replica_groups_per_slice[i].push_back( - std::move(replica_group_divided_by_slice[i])); - } - } - } - - hosts_.reserve(num_slices); - slices_.reserve(num_slices); - for (int i = 0; i < num_slices; ++i) { - hosts_.push_back(ReplicaGroups(std::move(replica_groups_per_host[i]))); - slices_.push_back(ReplicaGroups(std::move(replica_groups_per_slice[i]))); - } - - VLOG(2) << "Divided group assignment: " << ToString(); - return OkStatus(); -} - -} // namespace dtensor -} // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/group_assignment.h b/tensorflow/dtensor/mlir/group_assignment.h deleted file mode 100644 index f85eedcc98f466..00000000000000 --- a/tensorflow/dtensor/mlir/group_assignment.h +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ -#define TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/dtensor/cc/dstatus.h" - -namespace tensorflow { -namespace dtensor { - -// Arranges all replica IDs in a DTensor mesh in groups, used as an attribute -// on collective operations. -// -// A group assignment has two views: -// -// - The global mesh view contains replica IDs from all participant TPU slices. 
-// These replica IDs are identical to global device IDs in a DTensor mesh. -// - The local slice view contains per-slice device IDs understood and used by -// the TPU runtime on each slice. These device IDs are used to set replica -// IDs on each slice. -// -// Some notable common cases: -// -// - In a single-slice case, `slice_size` is set to the actual slice size -// (e.g., 32 for 4x4 DF). The global and local views are identical. -// - In a special topology case, `slice_size` is set to 8. -// - In a multi-topology case, `slice_size` is set to the size of a single -// topology. -// All topologies must have the same size. -class GroupAssignment { - public: - using ReplicaId = int; - - struct DeviceId { - public: - int slice_id; - int core_id; // within `slice_id` - }; - - // Maps global replica IDs to local device IDs consisting of a slice ID and a - // core-on-slice ID. - class ReplicaToDeviceMap { - public: - // Creates a default map that orders devices according to TF task IDs - // followed by device ordinals. - static ReplicaToDeviceMap DefaultReplicaToDeviceMap(int num_slices, - int slice_size); - - // Constructs a map directly, checking it's valid. - explicit ReplicaToDeviceMap(absl::flat_hash_map map); - - int num_slices() { return num_slices_; } - int num_cores() { return map_.size(); } - DeviceId device_id(ReplicaId replica_id) { return map_[replica_id]; } - - private: - absl::flat_hash_map map_; - int num_slices_; - }; - - // Creates a group assignment by converting from an MLIR attribute. - static StatusOr FromMLIR( - const mlir::DenseIntElementsAttr& group_assignment_attr, - ReplicaToDeviceMap replica_to_device_map); - - // Creates an MLIR attribute using the global view. - mlir::DenseIntElementsAttr GlobalToMLIR(mlir::MLIRContext& context) const { - return global_.ToMLIR(context); - } - - // Creates an MLIR attribute for a particular slice. - // Callers should make sure `slice_id` is >= 0 and < num_slices(). - StatusOr SliceToMLIR(mlir::MLIRContext& context, - int slice_id) const { - if (slice_id < 0 || slice_id >= num_slices()) - return errors::InvalidArgument("slide_id was not within bounds."); - return slices_[slice_id].ToMLIR(context); - } - - // Returns a string representation for debugging. - std::string ToString() const; - - // Returns true if every group in the global view only has replica IDs from - // the same slice. - bool IsWithinSlices() const; - - // Returns the number of slices in the local view. - int num_slices() const { return slices_.size(); } - - // These methods return attributes of the global view. - int num_groups() const { return global_.num_groups(); } - int group_size() const { return global_.group_size(); } - int num_replica_ids() const { return global_.num_replica_ids(); } - const std::vector>& replica_ids() const { - return global_.replica_ids(); - } - - // These methods return attributes of a particular slice. - // Callers should make sure `slice_id` is >= 0 and < num_slices(). 
- StatusOr num_groups(int slice_id) const { - if (slice_id < 0 || slice_id >= num_slices()) - return errors::InvalidArgument("slide_id was not within bounds."); - return slices_[slice_id].num_groups(); - } - StatusOr group_size(int slice_id) const { - if (slice_id < 0 || slice_id >= num_slices()) - return errors::InvalidArgument("slide_id was not within bounds."); - return slices_[slice_id].group_size(); - } - const std::vector>& replica_ids(int slice_id) const { - return slices_[slice_id].replica_ids(); - } - - // Returns the replica groups for collectives running on a particular host. - // Callers should make sure `slice_id` is >= 0 and < num_slices(). - const std::vector>& host_replica_ids(int slice_id) const { - return hosts_[slice_id].replica_ids(); - } - - private: - // Groups of consecutive replica IDs starting at 0. - class ReplicaGroups { - public: - // Creates an object, enforcing the requirements on `replica_ids_`. - explicit ReplicaGroups(std::vector> replica_ids); - - mlir::DenseIntElementsAttr ToMLIR(mlir::MLIRContext& context) const; - - std::string ToString() const; - - int num_groups() const { return replica_ids_.size(); } - int group_size() const { return replica_ids_.front().size(); } - int num_replica_ids() const { return num_groups() * group_size(); } - const std::vector>& replica_ids() const { - return replica_ids_; - } - - private: - // N groups of replica IDs, N > 0. All groups have the same size G, G > 0. - // All replica IDs are distinct values >= 0; - std::vector> replica_ids_; // replica ID order matters - }; - - // Creates an object but leaves `slices_` empty. `GlobalToSlices` should be - // called next to fill in `slices_`. - explicit GroupAssignment(ReplicaGroups global, - ReplicaToDeviceMap replica_to_device_map) - : global_(std::move(global)), - replica_to_device_map_(std::move(replica_to_device_map)) {} - - // Divides the global view along slice boundaries and fill in the slice view. - Status GlobalToSlices(); - - ReplicaGroups global_; - std::vector hosts_; // sorted by increasing slice ID - std::vector slices_; // sorted by increasing slice ID - ReplicaToDeviceMap replica_to_device_map_; -}; - -} // namespace dtensor -} // namespace tensorflow - -#endif // TENSORFLOW_DTENSOR_MLIR_GROUP_ASSIGNMENT_H_ diff --git a/tensorflow/dtensor/mlir/group_assignment_test.cc b/tensorflow/dtensor/mlir/group_assignment_test.cc deleted file mode 100644 index 106ca7fcba9441..00000000000000 --- a/tensorflow/dtensor/mlir/group_assignment_test.cc +++ /dev/null @@ -1,218 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/dtensor/mlir/group_assignment.h" - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/dtensor/cc/dstatus.h" - -namespace tensorflow { -namespace dtensor { -namespace { - -mlir::DenseIntElementsAttr CreateGroupAssignmentAttr( - mlir::MLIRContext& context, - const std::vector>& replica_ids) { - int num_groups = replica_ids.size(); - int group_size = replica_ids.front().size(); - llvm::SmallVector flat_replica_ids; - flat_replica_ids.reserve(num_groups * group_size); - for (const std::vector& group : replica_ids) { - CHECK_EQ(group.size(), group_size); - flat_replica_ids.insert(flat_replica_ids.end(), group.begin(), group.end()); - } - auto shaped_type = mlir::RankedTensorType::get( - {num_groups, group_size}, mlir::IntegerType::get(&context, 32)); - return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids); -} - -GroupAssignment CreateGroupAssignment( - mlir::MLIRContext& context, - const std::vector>& replica_ids, int num_slices, - int slice_size) { - mlir::DenseIntElementsAttr group_assignment_attr = - CreateGroupAssignmentAttr(context, replica_ids); - StatusOr group_assignment = GroupAssignment::FromMLIR( - group_assignment_attr, - GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( - num_slices, slice_size)); - TF_CHECK_OK(group_assignment.status()); - return *group_assignment; -} - -GroupAssignment CreateGroupAssignment( - mlir::MLIRContext& context, - const std::vector>& replica_ids, - absl::flat_hash_map - map) { - mlir::DenseIntElementsAttr group_assignment_attr = - CreateGroupAssignmentAttr(context, replica_ids); - StatusOr group_assignment = GroupAssignment::FromMLIR( - group_assignment_attr, - GroupAssignment::ReplicaToDeviceMap(std::move(map))); - TF_CHECK_OK(group_assignment.status()); - return *group_assignment; -} - -TEST(DTensorGroupAssignmentTest, InputOutput) { - mlir::MLIRContext context; - - mlir::DenseIntElementsAttr group_assignment_attr_in = - CreateGroupAssignmentAttr(context, - /*replica_ids=*/{{0, 1, 2, 3, 4, 5, 6, 7}}); - TF_ASSERT_OK_AND_ASSIGN( - auto group_assignment, - GroupAssignment::FromMLIR( - group_assignment_attr_in, - GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( - /*num_slices=*/1, /*slice_size=*/8))); - EXPECT_EQ(group_assignment.replica_ids(), - std::vector>({{0, 1, 2, 3, 4, 5, 6, 7}})); - - mlir::DenseIntElementsAttr group_assignment_attr_out = - group_assignment.GlobalToMLIR(context); - EXPECT_EQ(group_assignment_attr_out, group_assignment_attr_in); - - group_assignment_attr_out = - group_assignment.SliceToMLIR(context, /*slice_id=*/0).value(); - EXPECT_EQ(group_assignment_attr_out, group_assignment_attr_in); -} - -TEST(DTensorGroupAssignmentTest, BadInput) { - mlir::MLIRContext context; - - mlir::DenseIntElementsAttr indivisible_donut_size = - CreateGroupAssignmentAttr(context, - /*replica_ids=*/{{0, 1, 2, 3, 4, 5, 6, 7, 8}}); - EXPECT_FALSE( - GroupAssignment::FromMLIR( - indivisible_donut_size, - GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( - 
/*num_slices=*/1, /*slice_size=*/8)) - .ok()); - - mlir::DenseIntElementsAttr duplicate_replica_ids = - CreateGroupAssignmentAttr(context, - /*replica_ids=*/{{0, 1, 2, 3, 4, 5, 6, 6}}); - EXPECT_FALSE( - GroupAssignment::FromMLIR( - duplicate_replica_ids, - GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( - /*num_slices=*/1, /*slice_size=*/8)) - .ok()); -} - -TEST(DTensorGroupAssignmentTest, Properties) { - mlir::MLIRContext context; - GroupAssignment group_assignment = - CreateGroupAssignment(context, - /*replica_ids=*/{{0, 1, 2, 3}}, - /*num_slices=*/1, /*slice_size=*/4); - EXPECT_EQ(group_assignment.num_groups(), 1); - EXPECT_EQ(group_assignment.group_size(), 4); - EXPECT_EQ(group_assignment.num_replica_ids(), 4); - EXPECT_EQ(group_assignment.replica_ids(), - std::vector>({{0, 1, 2, 3}})); -} - -TEST(DTensorGroupAssignmentTest, GlobalAllReduceSingleDonut) { - mlir::MLIRContext context; - GroupAssignment group_assignment = - CreateGroupAssignment(context, - /*replica_ids=*/{{0, 1, 2, 3, 4, 5, 6, 7}}, - /*num_slices=*/1, /*slice_size=*/8); - EXPECT_TRUE(group_assignment.IsWithinSlices()); - EXPECT_EQ(group_assignment.replica_ids(), - std::vector>({{0, 1, 2, 3, 4, 5, 6, 7}})); - EXPECT_EQ(group_assignment.replica_ids(0), - std::vector>({{0, 1, 2, 3, 4, 5, 6, 7}})); -} - -TEST(DTensorGroupAssignmentTest, GlobalAllReduceTwoDonuts) { - mlir::MLIRContext context; - GroupAssignment group_assignment = - CreateGroupAssignment(context, - /*replica_ids=*/{{1, 2, 0, 3}}, - /*num_slices=*/2, /*slice_size=*/2); - EXPECT_FALSE(group_assignment.IsWithinSlices()); - EXPECT_EQ(group_assignment.replica_ids(), - std::vector>({{1, 2, 0, 3}})); - EXPECT_EQ(group_assignment.replica_ids(0), - std::vector>({{1, 0}})); - EXPECT_EQ(group_assignment.replica_ids(1), - std::vector>({{0, 1}})); -} - -TEST(DTensorGroupAssignmentTest, SubgroupAllReduceFourDonuts) { - mlir::MLIRContext context; - std::vector> global( - {{0, 4, 8, 12}, {1, 5, 9, 13}, {2, 6, 10, 14}, {3, 7, 11, 15}}); - GroupAssignment group_assignment = - CreateGroupAssignment(context, - /*replica_ids=*/global, - /*map=*/ - { - {0, {0, 0}}, - {1, {0, 1}}, - {2, {1, 0}}, - {3, {1, 1}}, - {4, {0, 2}}, - {5, {0, 3}}, - {6, {1, 2}}, - {7, {1, 3}}, - {8, {2, 0}}, - {9, {2, 1}}, - {10, {3, 0}}, - {11, {3, 1}}, - {12, {2, 2}}, - {13, {2, 3}}, - {14, {3, 2}}, - {15, {3, 3}}, - }); - EXPECT_FALSE(group_assignment.IsWithinSlices()); - EXPECT_EQ(group_assignment.replica_ids(), global); - EXPECT_EQ(group_assignment.host_replica_ids(0), - std::vector>({{0}, {1}})); - EXPECT_EQ(group_assignment.replica_ids(0), - std::vector>({{0, 2}, {1, 3}})); - EXPECT_EQ(group_assignment.host_replica_ids(1), - std::vector>({{2}, {3}})); - EXPECT_EQ(group_assignment.replica_ids(1), - std::vector>({{0, 2}, {1, 3}})); - EXPECT_EQ(group_assignment.host_replica_ids(2), - std::vector>({{8}, {9}})); - EXPECT_EQ(group_assignment.replica_ids(2), - std::vector>({{0, 2}, {1, 3}})); - EXPECT_EQ(group_assignment.host_replica_ids(3), - std::vector>({{10}, {11}})); - EXPECT_EQ(group_assignment.replica_ids(3), - std::vector>({{0, 2}, {1, 3}})); -} - -} // namespace -} // namespace dtensor -} // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/utils/BUILD b/tensorflow/dtensor/mlir/utils/BUILD index 0d9a5a39429fe6..72c9678a8bc70a 100644 --- a/tensorflow/dtensor/mlir/utils/BUILD +++ b/tensorflow/dtensor/mlir/utils/BUILD @@ -1,6 +1,5 @@ # DTensor Internal experimental code. 
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "if_google") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,12 +15,7 @@ cc_library( "collective_lowering.cc", "dtensor_mlir_passes_internal.cc", "update_tpu_metadata.cc", - ] + if_google( - [ - "collective_lowering_google.cc", - "update_tpu_metadata_google.cc", - ], - ), + ], hdrs = ["dtensor_mlir_passes_internal.h"], deps = [ "//tensorflow/compiler/mlir/tensorflow", @@ -42,7 +36,6 @@ cc_library( "//tensorflow/dtensor/mlir:device_utils", "//tensorflow/dtensor/mlir:dtensor_location", "//tensorflow/dtensor/mlir:dtensor_passes_inc_gen", - "//tensorflow/dtensor/mlir:group_assignment", "//tensorflow/dtensor/mlir:layout_parsing", "//tensorflow/dtensor/mlir:op_utils", "//tensorflow/dtensor/mlir:shape_utils", diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 753fe04e7898c5..61897371071e2f 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -55,7 +55,6 @@ limitations under the License. #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h" #include "tensorflow/dtensor/mlir/dtensor_location.h" -#include "tensorflow/dtensor/mlir/group_assignment.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" @@ -75,13 +74,6 @@ namespace { } // namespace namespace internal { -#ifdef PLATFORM_GOOGLE -mlir::LogicalResult EmitAllReduceForXlaGoogle( - mlir::MLIRContext& context, mlir::OpBuilder& builder, - mlir::TF::DTensorAllReduceOp all_reduce, - mlir::DenseIntElementsAttr group_assignment_attr, int32 key_base, - mlir::Operation** final_op); -#endif namespace ops_util = ::mlir::TF::collection_ops_util; constexpr int32 kUninitializedGroupKey = 0; @@ -117,10 +109,6 @@ mlir::LogicalResult EmitAllReduceForXla( mlir::TF::DTensorAllReduceOp all_reduce, mlir::DenseIntElementsAttr group_assignment_attr, int32 key_base, mlir::Operation** final_op) { -#ifdef PLATFORM_GOOGLE - return EmitAllReduceForXlaGoogle(context, builder, all_reduce, - group_assignment_attr, key_base, final_op); -#else constexpr char kCrossReplica[] = "CrossReplica"; // For TPUs, lower to XlaAllReduce straightforwardly. @@ -129,7 +117,6 @@ mlir::LogicalResult EmitAllReduceForXla( all_reduce.getInput(), all_reduce.getGroupAssignment(), all_reduce.getReduceOpAttr(), builder.getStringAttr(kCrossReplica)); return mlir::success(); -#endif } llvm::SmallVector GetGroupKeyOffsets( diff --git a/tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.cc b/tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.cc index e8a8bd1c464a86..f2b1429bc05034 100644 --- a/tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.cc +++ b/tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.cc @@ -13,22 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// LINT.IfChange #include "tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.h" #include -#include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h" + namespace tensorflow { namespace dtensor { -void AddDTensorAllReduceCombineOptimization(mlir::OpPassManager* pm){ +// Combine independent DTensorAllReduceOps from the same ClusterOp. +// Non-sea of donuts does not need this. It can rely on the XLA all-reduce +// combiner instead. +void AddDTensorAllReduceCombineOptimization(mlir::OpPassManager* pm) { // Experimental feature. If zero, the optimization for combining all reduces // with same group assignment and reduction, will not be done. - const char * env_str = ( - std::getenv("DTENSOR_ENABLE_COMBINE_ALL_REDUCES_OPTIMIZATION")); + const char* env_str = + (std::getenv("DTENSOR_ENABLE_COMBINE_ALL_REDUCES_OPTIMIZATION")); if (env_str && strcmp(env_str, "0") == 0) { return; } @@ -39,4 +43,3 @@ void AddDTensorAllReduceCombineOptimization(mlir::OpPassManager* pm){ } // namespace dtensor } // namespace tensorflow -// LINT.ThenChange(dtensor_mlir_passes_internal.cc) diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc index 32e57572741c02..64820b9c2bdbeb 100644 --- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc +++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc @@ -45,20 +45,6 @@ limitations under the License. namespace tensorflow { namespace dtensor { -namespace internal { -#ifdef PLATFORM_GOOGLE -extern void ComputeReplicaGroupSplitInfo(int requested_num_replicas, - int* num_replicas, - int* core_id_offset); -#else -// By default, all TPUs are connected, construct a single replica group. -void ComputeReplicaGroupSplitInfo(int requested_num_replicas, int* num_replicas, - int* core_id_local_offset) { - *num_replicas = requested_num_replicas; - *core_id_local_offset = 0; -} -#endif -} // namespace internal namespace { #define GEN_PASS_DEF_DTENSORUPDATETPUMETADATA #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" @@ -66,6 +52,13 @@ namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kFuncDeviceAttr[] = "tf.device"; +// By default, all TPUs are connected, construct a single replica group. +void ComputeReplicaGroupSplitInfo(int requested_num_replicas, int* num_replicas, + int* core_id_local_offset) { + *num_replicas = requested_num_replicas; + *core_id_local_offset = 0; +} + // Removes explicit device assignment on TPUExecute and _TPUCompileMlir ops. // As TPU execution replication logic is delegated to DTensorDevice, // DTensorDevice should handle replication and Placer would assign devices. 
@@ -98,8 +91,8 @@ Status UpdateMetadataProtoXlaSpmd(const Mesh& mesh_config, int core_id_local_offset = 0; int num_replicas = mesh_config.num_devices(); - internal::ComputeReplicaGroupSplitInfo(num_replicas, &num_replicas, - &core_id_local_offset); + ComputeReplicaGroupSplitInfo(num_replicas, &num_replicas, + &core_id_local_offset); // DTensor will interact with Xla Spmd by setting 1 replica and // `num_devices` number of cores per that replica to ensure @@ -194,8 +187,8 @@ Status UpdateMetadataProtoDtensorSpmd(const Mesh& mesh_config, int core_id_local_offset = 0; int num_replicas = mesh_config.num_devices(); - internal::ComputeReplicaGroupSplitInfo(num_replicas, &num_replicas, - &core_id_local_offset); + ComputeReplicaGroupSplitInfo(num_replicas, &num_replicas, + &core_id_local_offset); proto.set_num_replicas(num_replicas); diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 1d2b67f8cf0957..3aa07aaa63fb4a 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -19,7 +19,6 @@ tensorflow/core/tfrt/mla/mla_utils.h: tensorflow/core/tfrt/saved_model/saved_model_mira_impl.h: tensorflow/core/tfrt/utils/bridge_graph_analysis.h: tensorflow/dtensor/build_defs:.bzl -tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal:.cc tensorflow/dtensor/python/tests/test_backend_name:.py tensorflow/dtensor/python/tests/test_backend_util:.py tensorflow/examples/custom_ops_doc/multiplex_1/BUILD: From b065b06e7ea2c58824fa68732de61790b517ad2e Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Tue, 25 Jul 2023 18:04:26 -0700 Subject: [PATCH 162/410] #tf-data-service Fix a ClangTidy warning. use '= default' to define a trivial default constructor PiperOrigin-RevId: 551051722 --- tensorflow/core/data/service/dataset_store.cc | 2 -- tensorflow/core/data/service/dataset_store.h | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/core/data/service/dataset_store.cc b/tensorflow/core/data/service/dataset_store.cc index 18a1917da59859..f7db6301003f45 100644 --- a/tensorflow/core/data/service/dataset_store.cc +++ b/tensorflow/core/data/service/dataset_store.cc @@ -55,8 +55,6 @@ Status FileSystemDatasetStore::Get( return OkStatus(); } -MemoryDatasetStore::MemoryDatasetStore() {} - Status MemoryDatasetStore::Put(const std::string& key, const DatasetDef& dataset) { auto& stored_dataset = datasets_[key]; diff --git a/tensorflow/core/data/service/dataset_store.h b/tensorflow/core/data/service/dataset_store.h index 100cca885ec4ec..790d0d7acfaf09 100644 --- a/tensorflow/core/data/service/dataset_store.h +++ b/tensorflow/core/data/service/dataset_store.h @@ -61,7 +61,7 @@ class FileSystemDatasetStore : public DatasetStore { // dispatcher doesn't have a work directory configured. class MemoryDatasetStore : public DatasetStore { public: - MemoryDatasetStore(); + MemoryDatasetStore() = default; MemoryDatasetStore(const MemoryDatasetStore&) = delete; MemoryDatasetStore& operator=(const MemoryDatasetStore&) = delete; From 4d6dd388319f450b68e8ad465fe92357d9ebac09 Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Tue, 25 Jul 2023 18:16:12 -0700 Subject: [PATCH 163/410] Stop producing CopyToMesh Op. To unblock adding automated mesh placement, which would need to insert CopyToMesh ops before layout propagation. 
Further work planned:
- Postpone the CopyToMesh-to-DTensor-Send/Recv lowering to after layout
  propagation (but before SPMD expansion).
- Consolidate CopyToMesh followed by Relayout (user-added relayout) and
  CopyToMesh followed by DTensorLayout (layout added by layout propagation).

This CL includes:
- CopyToMesh is now limited to use as a transitory Op during lowering.
- Old uses of CopyToMesh are replaced with Relayout, including from api.py.
- CopyToMesh only accepts a Mesh attr.
- Layout propagation adds target_layout to DTensorSend and source_layout to
  DTensorRecv, which are then used in SPMD expansion for lowering of those ops.
- Updated pattern recognition for shapes and layouts in Save/Restore op
  support.
- The global shape now annotates the GlobalShape attr, not the _global_shape
  attr, on DTensorLayout ops, since _global_shape is mostly ignored on the
  DTensorLayout op.

This amounts to deprecating the CopyToMeshGrad Op; it will be removed later.

PiperOrigin-RevId: 551053876
---
 .../compat/ops_history_v2/CopyToMesh.pbtxt    |  4 +-
 .../ops_history_v2/CopyToMeshGrad.pbtxt       |  9 +-
 tensorflow/dtensor/cc/constants.h             |  8 ++
 tensorflow/dtensor/cc/dtensor_ops.cc          |  6 +-
 .../dtensor/mlir/annotate_global_shape.cc     | 14 ++-
 .../dtensor/mlir/dtensor_mlir_passes.cc       | 14 +--
 tensorflow/dtensor/mlir/dtensor_send_recv.cc  | 83 ++++++++--------
 .../expansions/dtensor_op_spmd_expander.cc    | 11 ++-
 .../expansions/save_restore_spmd_expander.cc  |  1 +
 .../mlir/handle_cross_cluster_dependencies.cc | 95 +++++++++++++------
 tensorflow/dtensor/mlir/ir/tf_dtensor.td      | 17 ++--
 .../dtensor/mlir/layout_propagation_v2.cc     | 68 ++++++++++++-
 tensorflow/dtensor/mlir/merge_clusters.cc     |  5 +-
 tensorflow/dtensor/mlir/mesh_propagation.cc   | 52 ++++------
 .../dtensor/mlir/op_to_device_cluster.cc      | 13 ++-
 .../dtensor/mlir/restore_shape_inference.cc   | 26 +++--
 .../handle_cross_cluster_dependencies.mlir    | 55 ++++++-----
 .../mlir/tests/layout_propagation_v2.mlir     |  9 +-
 .../dtensor/mlir/tests/lower_send_recv.mlir   | 14 +--
 .../dtensor/mlir/tests/mesh_propagation.mlir  | 29 +-----
 .../mlir/tests/op_to_device_cluster.mlir      |  4 +-
 .../mlir/tests/restore_and_assign.mlir        | 11 ++-
 .../mlir/tests/restore_shape_inference.mlir   |  8 +-
 .../dtensor/mlir/tests/spmd_dtensor_ops.mlir  | 38 ++++----
 tensorflow/dtensor/python/api.py              |  5 +-
 tensorflow/dtensor/python/dtensor_device.py   |  8 --
 .../dtensor/python/tests/device_test.py       | 10 +-
 .../dtensor/python/tests/multi_mesh_test.py   | 39 ++++++++
 28 files changed, 402 insertions(+), 254 deletions(-)

diff --git a/tensorflow/core/ops/compat/ops_history_v2/CopyToMesh.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CopyToMesh.pbtxt
index 5bf27035608956..50e0a66e784a74 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/CopyToMesh.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/CopyToMesh.pbtxt
@@ -1,4 +1,4 @@
-op {
+op {
   name: "CopyToMesh"
   input_arg {
     name: "input"
     type_attr: "T"
   }
   output_arg {
     name: "output"
     type_attr: "T"
   }
   attr {
-    name: "layout"
+    name: "mesh"
     type: "string"
   }
   attr {
diff --git a/tensorflow/core/ops/compat/ops_history_v2/CopyToMeshGrad.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CopyToMeshGrad.pbtxt
index a449a512850f10..e75ffe9bc3eb37 100644
--- a/tensorflow/core/ops/compat/ops_history_v2/CopyToMeshGrad.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history_v2/CopyToMeshGrad.pbtxt
@@ -1,4 +1,4 @@
-op {
+op {
   name: "CopyToMeshGrad"
   input_arg {
     name: "input"
     type_attr: "T"
   }
   input_arg {
     name: "forward_input"
     type_attr: "T"
   }
   output_arg {
     name: "output"
     type_attr: "T"
   }
-  attr {
-    name: "reference_layout"
-    type: "string"
-    default_value {
-      s: ""
-    }
-  }
   attr {
     name: "T"
     type: "type"
   }
diff --git
a/tensorflow/dtensor/cc/constants.h b/tensorflow/dtensor/cc/constants.h index b14ea2438af7f2..8f11d62981d639 100644 --- a/tensorflow/dtensor/cc/constants.h +++ b/tensorflow/dtensor/cc/constants.h @@ -122,6 +122,14 @@ static constexpr char kCacheKey[] = "dtensor.cache_key"; // from. static constexpr char kFromArgIndex[] = "dtensor.from_arg_index"; +// To record the target layout of a DTensorSend, which is computed after +// layout propagation. +static constexpr char kTargetLayoutAttr[] = "target_layout"; + +// To record the source layout of a DTensorRecv, which is computed after +// layout propagation. +static constexpr char kSourceLayoutAttr[] = "source_layout"; + // An attribute that determines whether a tensor is a sparse tensor. If this // attribute exists in a tensor, then this tensor is a sparse tensor. static constexpr char kSparseValue[] = "tf._sparse"; diff --git a/tensorflow/dtensor/cc/dtensor_ops.cc b/tensorflow/dtensor/cc/dtensor_ops.cc index 83c82e5a0582a6..df1172c568bdd7 100644 --- a/tensorflow/dtensor/cc/dtensor_ops.cc +++ b/tensorflow/dtensor/cc/dtensor_ops.cc @@ -49,20 +49,22 @@ REGISTER_OP("RelayoutLike") .Attr("U: type") .SetShapeFn(UnchangedShape); +// FIXME(b/271292250): Add DTensor suffix to signal this is a meta Op +// Op. Or remove this altogether, if there is no use for it. // Copy `input` to the given mesh and layout. REGISTER_OP("CopyToMesh") .Input("input: T") .Output("output: T") - .Attr("layout: string") + .Attr("mesh: string") .Attr("T: type") .SetShapeFn(UnchangedShape); +// FIXME(b/271292250): Remove this Op It is no longer used. // Gradient of CopyToMesh. REGISTER_OP("CopyToMeshGrad") .Input("input: T") .Input("forward_input: T") // To infer the output mesh. .Output("output: T") - .Attr("reference_layout: string = ''") // To infer the sharding spec. .Attr("T: type") .SetShapeFn(UnchangedShape); diff --git a/tensorflow/dtensor/mlir/annotate_global_shape.cc b/tensorflow/dtensor/mlir/annotate_global_shape.cc index 32d1e0bba95094..e251254e1d38cf 100644 --- a/tensorflow/dtensor/mlir/annotate_global_shape.cc +++ b/tensorflow/dtensor/mlir/annotate_global_shape.cc @@ -32,6 +32,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/dtensor/cc/constants.h" #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" +#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" +#include "tensorflow/dtensor/mlir/value_utils.h" namespace tensorflow { namespace dtensor { @@ -84,7 +86,17 @@ void AnnotateOperationGlobalShape(mlir::Operation* op, for (const auto& result_type : op->getResultTypes()) op_global_shape.emplace_back(ConvertTypeToTensorShapeAttr(result_type)); - op->setAttr(kGlobalShape, builder->getArrayAttr(op_global_shape)); + if (auto layout_op = mlir::dyn_cast(op)) { + // Shape of Resource type is incorrect when it is a variable. + // The global shape is undefined in this case; and usually we are supposed + // to propagate the value shape due to how resource variable layout is + // currently represented in DTensor. 
+ if (!IsResourceType(op->getResult(0))) { + layout_op.setGlobalShapeAttr(op_global_shape[0]); + } + } else { + op->setAttr(kGlobalShape, builder->getArrayAttr(op_global_shape)); + } } // Pass that annotates function argument/return values and all operation with diff --git a/tensorflow/dtensor/mlir/dtensor_mlir_passes.cc b/tensorflow/dtensor/mlir/dtensor_mlir_passes.cc index d3db9ec3d626d6..8ddf5684d88e66 100644 --- a/tensorflow/dtensor/mlir/dtensor_mlir_passes.cc +++ b/tensorflow/dtensor/mlir/dtensor_mlir_passes.cc @@ -124,7 +124,7 @@ void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options, pm->addPass(CreateDTensorSparseTensorToDenseTensor()); // After shape inference, there may be unused constants ops added when - // propagating caller-callee constants. As DTensor mesh/layout propgation + // propagating caller-callee constants. As DTensor mesh/layout propagation // passes assumes that there are no unreachable ops, removes trivial unused // ops. Note that `Canonicalizer` pass in TF includes similar optimization. // However, canonicalizer pass also rewrites some ops and may remove `_layout` @@ -169,11 +169,6 @@ void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options, // Merge Clusters pm->addPass(CreateDTensorMergeClustersPass()); - // Mark all ops and functions with global shape attribute to preserve global - // shape information as it is needed during Layout Propagation and SPMD - // expansion. - pm->addPass(CreateDTensorAnnotateGlobalShape()); - //////// // Propagate layout to all ops in graph. @@ -183,6 +178,11 @@ void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options, // This pass fills in all missing shapes caused by tf.RestoreV2 ops. pm->addPass(CreateDTensorInferShapesForRestoreV2Op()); + // Mark all ops and functions with global shape attribute to preserve global + // shape information as it is needed during Layout Propagation and SPMD + // expansion. + pm->addPass(CreateDTensorAnnotateGlobalShape()); + pm->addPass(CreateDTensorLayoutPropagationPassV2()); // Expand graph to SPMD form given layouts are annotated to all ops. @@ -227,7 +227,7 @@ void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options, // DTensorReduceScatter lowering should come before DTensorAllReduce // and DTensorAllScatter lowerings since for some devices DTensorReduceScatter - // will be decomposed into an DTensorAllReduce+DTensorScatter. + // will be decomposed into a DTensorAllReduce+DTensorScatter. pm->addPass(CreateDTensorReduceScatterLoweringPass()); // For large enough reduction groups in reduction ops, upcast the input diff --git a/tensorflow/dtensor/mlir/dtensor_send_recv.cc b/tensorflow/dtensor/mlir/dtensor_send_recv.cc index 5b3500a6db2daa..4a17b4db82dca6 100644 --- a/tensorflow/dtensor/mlir/dtensor_send_recv.cc +++ b/tensorflow/dtensor/mlir/dtensor_send_recv.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" @@ -116,11 +117,10 @@ StatusOr LowerDTensorSendToTFOp( builder.setInsertionPointAfter(send_input.getDefiningOp()); std::string tensor_name = dtensor_send.getKey().str(); - Layout target_layout = dtensor_send.getTargetLayout(); + Mesh target_mesh = dtensor_send.getTargetMesh(); absl::Span sending_devices = send_input_layout.mesh().local_devices(); - absl::Span receiving_devices = - target_layout.mesh().local_devices(); + absl::Span receiving_devices = target_mesh.local_devices(); mlir::Operation* lowered_send_op; lowered_send_op = builder.create( @@ -186,7 +186,7 @@ StatusOr LowerDTensorSendToXlaOp( // specific local tensor type needed, if different from the Recv op output type. StatusOr LowerDTensorRecvToXlaOp( mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type) { - const bool recv_at_cpu = dtensor_recv.getLayout().mesh().is_cpu_mesh(); + const bool recv_at_cpu = dtensor_recv.getMesh().is_cpu_mesh(); mlir::Operation* recv_xla_op = nullptr; mlir::OpBuilder builder(dtensor_recv); @@ -256,9 +256,8 @@ StatusOr LowerDTensorSendFromCPUToTFOp( absl::Span sending_devices = send_input_layout.mesh().local_devices(); - Layout target_layout = dtensor_send.getTargetLayout(); - absl::Span receiving_devices = - target_layout.mesh().local_devices(); + Mesh target_mesh = dtensor_send.getTargetMesh(); + absl::Span receiving_devices = target_mesh.local_devices(); std::string tensor_name = dtensor_send.getKey().str(); @@ -276,7 +275,7 @@ StatusOr LowerDTensorSendFromCPUToTFOp( // Lowers DTensorRecv op to TF Recv Op. 
StatusOr LowerDTensorRecvFromCPUToTFOp( const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv) { - const Layout& recv_layout = dtensor_recv.getLayout(); + const Mesh& recv_mesh = dtensor_recv.getMesh(); auto recv_cluster = dtensor_recv->getParentOfType(); @@ -286,8 +285,7 @@ StatusOr LowerDTensorRecvFromCPUToTFOp( builder.setInsertionPoint(dtensor_recv); std::string tensor_name = dtensor_recv.getKey().str(); absl::Span sending_devices = send_mesh.local_devices(); - absl::Span receiving_devices = - recv_layout.mesh().local_devices(); + absl::Span receiving_devices = recv_mesh.local_devices(); mlir::Operation* lowered_recv_op; mlir::Location loc = dtensor_recv.getLoc(); @@ -306,7 +304,7 @@ StatusOr LowerDTensorRecvFromCPUToTFOp( StatusOr LowerDTensorRecvToTFOp( const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type) { - const Layout& recv_layout = dtensor_recv.getLayout(); + const Mesh& recv_mesh = dtensor_recv.getMesh(); auto recv_cluster = dtensor_recv->getParentOfType(); @@ -314,8 +312,7 @@ StatusOr LowerDTensorRecvToTFOp( builder.setInsertionPoint(dtensor_recv); std::string tensor_name = dtensor_recv.getKey().str(); absl::Span sending_devices = send_mesh.local_devices(); - absl::Span receiving_devices = - recv_layout.mesh().local_devices(); + absl::Span receiving_devices = recv_mesh.local_devices(); mlir::Location loc = dtensor_recv.getLoc(); mlir::Operation* lowered_recv_op = builder.create( @@ -447,9 +444,8 @@ StatusOr LowerOneToOneDTensorRecvToTFHostRecv( mlir::TensorType recv_type = dtensor_recv.getType(); bool i32_copy = recv_type.getElementType().isInteger(32); - TF_ASSIGN_OR_RETURN( - mlir::TensorType local_recv_type, - LocalTypeFromGlobalType(dtensor_recv.getLayout(), recv_type)); + TF_ASSIGN_OR_RETURN(mlir::TensorType local_recv_type, + LocalTypeFromGlobalType(recv_layout, recv_type)); mlir::TensorType local_output_type = i32_copy ? mlir::RankedTensorType::get(local_recv_type.getShape(), builder.getIntegerType(64)) @@ -535,6 +531,7 @@ bool SendRecvOpUsesXla(const Mesh& send_mesh, const Mesh& recv_mesh) { } } // namespace +// FIXME(b/271292250): Remove the recv_op argument. StatusOr LowerDTensorSend(mlir::Operation* send_op, mlir::Operation* recv_op) { auto dtensor_send = llvm::cast(send_op); @@ -544,8 +541,17 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, ExtractRequiredLayoutFromOperand(dtensor_send.getInput())); const Mesh& input_mesh = input_layout.mesh(); - const Layout& recv_layout = dtensor_send.getTargetLayout(); - const Mesh& target_mesh = recv_layout.mesh(); + const Mesh& target_mesh = dtensor_send.getTargetMesh(); + + auto layout_attr = + dtensor_send->getAttrOfType(kTargetLayoutAttr); + + if (!layout_attr) { + return absl::InvalidArgumentError("target_layout is not found"); + } + + const Layout& recv_layout = layout_attr.getValue(); + bool one_to_one = IsOneToOneMeshTransfer(input_layout, recv_layout); // Force string type to not use the allreduce/broadcast optimization as there // is no string type allreduce. @@ -624,8 +630,7 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, dtensor_send->moveBefore(yield); // Lower DTensorSend op to actual TF op. - TF_ASSIGN_OR_RETURN(const Mesh recv_mesh, - ExtractDeviceMeshEnclosingCluster(recv_op)); + const Mesh& recv_mesh = recv_layout.mesh(); if (SendRecvOpUsesXla(input_layout.mesh(), recv_mesh)) { // Lower DTensorSend op to Xla Send ops. 
TF_ASSIGN_OR_RETURN( @@ -659,23 +664,26 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, return lowered_send; } +// FIXME(b/271292250): Remove the send_op argument. StatusOr LowerDTensorRecv(mlir::Operation* send_op, mlir::Operation* recv_op) { auto dtensor_recv = llvm::cast(recv_op); - auto dtensor_send = llvm::dyn_cast(send_op); - - TF_ASSIGN_OR_RETURN(const Layout send_layout, - ExtractRequiredLayoutFromOperand(send_op->getOperand(0))); - - TF_ASSIGN_OR_RETURN(const Mesh send_mesh, - ExtractDeviceMeshEnclosingCluster(send_op)); TF_ASSIGN_OR_RETURN(const Layout output_layout, ExtractRequiredSingleLayoutFromOp(recv_op)); mlir::Operation* lowered_recv; - const Layout recv_layout = dtensor_recv.getLayout(); - const Mesh& recv_mesh = recv_layout.mesh(); + auto layout_attr = + dtensor_recv->getAttrOfType(kSourceLayoutAttr); + if (!layout_attr) { + return absl::InvalidArgumentError("source_layout is not found"); + } + const Layout send_layout = layout_attr.getValue(); + + const Mesh send_mesh = send_layout.mesh(); + + const Mesh& recv_mesh = dtensor_recv.getMesh(); + const Layout& recv_layout = output_layout; mlir::OpBuilder builder(dtensor_recv); bool cpu_to_cpu = recv_mesh.is_cpu_mesh() && send_mesh.is_cpu_mesh(); @@ -686,15 +694,14 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, bool is_string_type = IsStringType(dtensor_recv.getType().getElementType()); if (IsGpuToHostMeshTransfer(send_mesh, recv_mesh) && - (one_to_one && - (!dtensor_recv.getLayout().IsFullyReplicated() || is_string_type))) { + (one_to_one && (!recv_layout.IsFullyReplicated() || is_string_type))) { TF_ASSIGN_OR_RETURN(lowered_recv, LowerOneToOneDTensorRecvToTFHostRecv( send_mesh, recv_layout, dtensor_recv)); // erase the send op here iff not targeting a gpu if (recv_mesh.device_type() != "GPU") { - dtensor_send.erase(); + send_op->erase(); } return lowered_recv; @@ -706,10 +713,10 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, (send_recv_xla && recv_mesh.is_cpu_mesh())) { // Recv can be lowered directly for a 1-to-1 transfer between host and // device (*for XLA/TPUs). - TF_ASSIGN_OR_RETURN(mlir::TensorType local_output_type, - LocalTypeFromGlobalType( - dtensor_recv.getLayout(), - dtensor_recv.getType().cast())); + TF_ASSIGN_OR_RETURN( + mlir::TensorType local_output_type, + LocalTypeFromGlobalType( + recv_layout, dtensor_recv.getType().cast())); TF_ASSIGN_OR_RETURN( lowered_recv, LowerDTensorRecvToXlaOp(dtensor_recv, local_output_type)); dtensor_recv->replaceAllUsesWith(lowered_recv); @@ -722,7 +729,7 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, : LowerDTensorRecvToTFOp; // For other send/recv layouts, the tensor needs to be replicated. 
- if (!dtensor_recv.getLayout().IsFullyReplicated()) { + if (!recv_layout.IsFullyReplicated()) { return absl::InvalidArgumentError( "CopyToMesh where target mesh is GPU/TPU requires a replicated " "target layout."); @@ -844,7 +851,7 @@ StatusOr LowerDTensorSendAndRecv(mlir::Operation* send_op, auto dtensor_send = llvm::cast(send_op); auto dtensor_recv = llvm::dyn_cast(recv_op); - const Mesh recv_mesh = dtensor_recv.getLayout().mesh(); + const Mesh recv_mesh = dtensor_recv.getMesh(); TF_ASSIGN_OR_RETURN( std::optional send_mesh, ExtractDeviceMeshFromOp( diff --git a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc index 07805ee4d23ded..6e8d896545a02c 100644 --- a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc @@ -40,6 +40,8 @@ namespace tensorflow { namespace dtensor { namespace { +// FIXME(feyu): This function should take layouts as arguments. It doesn't need +// the ops. // Validates send/recv layout and mesh configurations. Among other things, this // checks for below constraints. // 1. Src/target layouts have non empty mesh. @@ -60,11 +62,14 @@ Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send, return absl::InvalidArgumentError( "Input to DTensorSend must have specified layout."); + TF_ASSIGN_OR_RETURN(const Layout output_layout, + ExtractRequiredSingleLayoutFromOp(dtensor_recv)); + const Layout& send_layout = send_layout_or_null.value(); - const Layout recv_layout = dtensor_recv.getLayout(); + const Layout& recv_layout = output_layout; + const Mesh& recv_mesh = dtensor_recv.getMesh(); const Mesh& send_mesh = send_layout.mesh(); - const Mesh& recv_mesh = recv_layout.mesh(); // If any one of send/recv mesh are empty, return error. if (send_mesh.IsEmpty() || recv_mesh.IsEmpty()) @@ -322,7 +327,7 @@ DTensorRecvSPMDExpander::ComputeLayoutForward( return absl::InvalidArgumentError( llvm::formatv("Expecting DTensorRecvOp but got {0}", OpName(op)).str()); } - return llvm::DenseMap({{0, dtensor_recv.getLayout()}}); + return llvm::DenseMap(); } StatusOr> diff --git a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc index 0b52945306e1be..f98892f1915ce2 100644 --- a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc @@ -821,6 +821,7 @@ StatusOr> GetLayoutsFromAssignVariableOps( // an IdentityOp, CastOp, or a DTensorSend op on the path. So, skip past // these ops first. while (llvm::isa(consuming_op)) { if (auto send_op = mlir::dyn_cast_or_null(consuming_op)) { diff --git a/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc b/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc index 14420275f837c4..422a4d4a8cafd1 100644 --- a/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc +++ b/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -49,7 +50,7 @@ constexpr char kMissingMeshErrorMsg[] = constexpr char kInvalidTensorTransferErrorMsg[] = "CopyToMeshOp must be used to send data across mesh."; -constexpr char kInvalidLayoutMsg[] = +constexpr char kInvalidMeshMsg[] = "found CopyToMesh with invalid layout. Found layout {0}. Error: {1}."; // Extracts mesh from `cluster`. @@ -87,23 +88,17 @@ mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op, auto copy_to_mesh = llvm::dyn_cast(operand->getOwner()); assert(copy_to_mesh); - const std::string layout_attr = copy_to_mesh.getLayout().str(); - StatusOr layout = Layout::FromString(layout_attr); - if (!layout.ok()) - return copy_to_mesh.emitOpError(llvm::formatv( - kInvalidLayoutMsg, layout_attr, layout.status().message())); + const std::string mesh_attr = copy_to_mesh.getMesh().str(); + StatusOr mesh = Mesh::FromString(mesh_attr); + if (!mesh.ok()) + return copy_to_mesh.emitOpError( + llvm::formatv(kInvalidMeshMsg, mesh_attr, mesh.status().message())); mlir::OpBuilder builder(&cluster.GetBody().front()); mlir::Operation* cloned_op = builder.clone(*const_op); - mlir::TensorType type = - cloned_op->getResult(0).getType().cast(); - auto layout_op = builder.create( - const_op->getLoc(), cloned_op->getResult(0), - mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout), - mlir::TF::ShapeAttr::get(builder.getContext(), type)); copy_to_mesh.getOutput().replaceUsesWithIf( - layout_op.getOutput(), [&](mlir::OpOperand& operand) { + cloned_op->getResult(0), [&](mlir::OpOperand& operand) { return cluster.getOperation()->isProperAncestor(operand.getOwner()); }); @@ -176,9 +171,7 @@ mlir::LogicalResult CloneConstantsAcrossMesh( } // Handles CopyToMesh ops within the same cluster. These should not lower to -// send or recv as we can directly replace it with a Relayout. If the source and -// target layouts are the same, this is handled separately within Relayout -// lowering. +// send or recv; we can directly replace it with an Identity. 
mlir::LogicalResult HandleCopyToMeshWithinCluster( mlir::tf_device::ClusterOp cluster) { Mesh current_mesh; @@ -205,9 +198,9 @@ mlir::LogicalResult HandleCopyToMeshWithinCluster( } } mlir::OpBuilder builder(op); - auto relayout_op = builder.create( - op.getLoc(), input.getType(), input, op.getLayout()); - op->getResult(0).replaceAllUsesWith(relayout_op.getOutput()); + auto identity_op = builder.create( + op.getLoc(), input.getType(), input); + op->getResult(0).replaceAllUsesWith(identity_op.getOutput()); op->erase(); return mlir::WalkResult::advance(); }); @@ -234,20 +227,20 @@ mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh, mlir::OpBuilder builder(value_to_send.getParentBlock()->getTerminator()); const std::string op_key = - llvm::formatv("communication_key_{0}_{1}", copy_to_mesh.getLayout(), + llvm::formatv("communication_key_{0}_{1}", copy_to_mesh.getMesh(), *send_recv_counter) .str(); - const std::string layout_attr = copy_to_mesh.getLayout().str(); - auto layout_or_status = Layout::FromString(layout_attr); - if (!layout_or_status.ok()) + const std::string mesh_attr = copy_to_mesh.getMesh().str(); + auto mesh_or_status = Mesh::FromString(mesh_attr); + if (!mesh_or_status.ok()) return copy_to_mesh.emitOpError(llvm::formatv( - kInvalidLayoutMsg, layout_attr, layout_or_status.status().message())); + kInvalidMeshMsg, mesh_attr, mesh_or_status.status().message())); // Create send op that sends data from input cluster to target cluster. - const Layout& target_layout = layout_or_status.value(); + const Mesh& target_mesh = mesh_or_status.value(); builder.create( copy_to_mesh.getLoc(), value_to_send, builder.getStringAttr(op_key), - mlir::dtensor::LayoutAttr::get(context, target_layout)); + mlir::dtensor::MeshAttr::get(context, target_mesh)); // Create recv op that recvs data from send op. auto tensor_type = value_to_send.getType().dyn_cast(); @@ -261,7 +254,7 @@ mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh, copy_to_mesh.getLoc(), value_to_send.getType(), builder.getStringAttr(op_key), mlir::TF::ShapeAttr::get(context, tensor_type), - mlir::dtensor::LayoutAttr::get(context, target_layout)); + mlir::dtensor::MeshAttr::get(context, target_mesh)); // Replace value for recv ops for all usages of `copy_to_mesh` op. copy_to_mesh.replaceAllUsesWith(recv_op.getOutput()); @@ -358,6 +351,50 @@ mlir::LogicalResult ReplaceCopyToMeshWithVirtualSendRecv( return result; } +// Inserts a CopyToMesh op to represent the mesh part of a Relayout. +// Only inserted for cross-mesh Relayout/RelayoutLike ops; same-mesh relayouts are left unchanged. 
+mlir::LogicalResult InsertCopyToMesh(mlir::tf_device::ClusterOp cluster) { + Mesh mesh; + if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh))) { + return mlir::failure(); + } + llvm::SmallVector relayout_ops; + + cluster.walk([&](mlir::Operation* op) { + if (!mlir::isa(op) && + !mlir::isa(op)) { + return; + } + relayout_ops.push_back(op); + }); + + for (mlir::Operation* op : relayout_ops) { + mlir::Value input = op->getOperand(0); + + auto input_cluster = + mlir::dyn_cast(input.getDefiningOp()); + if (!input_cluster) { + input_cluster = + input.getDefiningOp()->getParentOfType(); + } + if (!input_cluster) { + op->emitOpError() << "Input cluster not found."; + return mlir::failure(); + } + Mesh input_mesh; + if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) { + return mlir::failure(); + } + if (input_mesh == mesh) continue; + mlir::OpBuilder builder(op); + + auto new_op = builder.create( + op->getLoc(), op->getResult(0).getType(), input, mesh.ToString()); + op->replaceUsesOfWith(input, new_op.getResult()); + } + return mlir::success(); +} + struct DTensorHandleCrossClusterDependencies : public impl::DTensorHandleCrossClusterDependenciesBase< DTensorHandleCrossClusterDependencies> { @@ -374,6 +411,10 @@ struct DTensorHandleCrossClusterDependencies }); int send_recv_counter = 0; + for (auto cluster : clusters) { + if (mlir::failed(InsertCopyToMesh(cluster))) return signalPassFailure(); + } + for (auto cluster : clusters) { if (mlir::failed(CloneConstantsAcrossMesh(cluster))) return signalPassFailure(); diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.td b/tensorflow/dtensor/mlir/ir/tf_dtensor.td index b8d56515af8053..999d8df041e74d 100644 --- a/tensorflow/dtensor/mlir/ir/tf_dtensor.td +++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.td @@ -39,6 +39,11 @@ def DTensor_LayoutAttr : DTensor_DTensorAttr<"Layout", "layout"> { let convertFromStorage = "$_self.cast().getValue()"; } +def DTensor_MeshAttr : DTensor_DTensorAttr<"Mesh", "mesh"> { + let returnType = "mlir::dtensor::MeshAttr::Mesh"; + let convertFromStorage = "$_self.cast().getValue()"; +} + //===----------------------------------------------------------------------===// // DTensor op definitions //===----------------------------------------------------------------------===// @@ -49,7 +54,7 @@ def Tf_DTensorSend : TF_Op<"DTensorSend", []> { let arguments = (ins TF_Tensor:$input, StrAttr:$key, - DTensor_LayoutAttr:$target_layout + DTensor_MeshAttr:$target_mesh ); let results = (outs); @@ -63,7 +68,7 @@ def Tf_DTensorRecv : TF_Op<"DTensorRecv", []> { let arguments = (ins StrAttr:$key, TF_ShapeAttr:$shape, - DTensor_LayoutAttr:$layout + DTensor_MeshAttr:$mesh ); let results = (outs TF_Tensor:$output @@ -127,8 +132,7 @@ def TF_CopyToMeshOp : TF_Op<"CopyToMesh", [Pure]> { let arguments = (ins TF_Tensor:$input, - - StrAttr:$layout + StrAttr:$mesh ); let results = (outs @@ -143,10 +147,7 @@ def TF_CopyToMeshGradOp : TF_Op<"CopyToMeshGrad", [Pure]> { let arguments = (ins TF_Tensor:$input, - - TF_Tensor:$forward_input, - - StrAttr:$reference_layout + TF_Tensor:$forward_input ); let results = (outs diff --git a/tensorflow/dtensor/mlir/layout_propagation_v2.cc b/tensorflow/dtensor/mlir/layout_propagation_v2.cc index 6cc87e7e3c564c..d2ccee7dc2bbd1 100644 --- a/tensorflow/dtensor/mlir/layout_propagation_v2.cc +++ b/tensorflow/dtensor/mlir/layout_propagation_v2.cc @@ -57,6 +57,7 @@ limitations under the License. 
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h" #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" +#include "tensorflow/dtensor/mlir/dtensor_send_recv.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/op_utils.h" @@ -711,6 +712,8 @@ mlir::LogicalResult InsertDTensorLayoutOps( int num_users = std::distance(users.begin(), users.end()); if (num_users == 1 && mlir::isa(*users.begin())) continue; + auto layout_attr = mlir::dtensor::LayoutAttr::get(builder.getContext(), + merged_layout.second); builder.setInsertionPointAfterValue(merged_layout.first); // Handles resource and variant as the real shape is embedded in the // resource type elements. @@ -718,9 +721,7 @@ mlir::LogicalResult InsertDTensorLayoutOps( if (auto type = value_type.dyn_cast()) { auto layout_op = builder.create( - merged_layout.first.getLoc(), merged_layout.first, - mlir::dtensor::LayoutAttr::get(builder.getContext(), - merged_layout.second), + merged_layout.first.getLoc(), merged_layout.first, layout_attr, mlir::TF::ShapeAttr::get(builder.getContext(), type)); llvm::SmallPtrSet exception{layout_op}; merged_layout.first.replaceAllUsesExcept(layout_op.getOutput(), @@ -728,9 +729,67 @@ mlir::LogicalResult InsertDTensorLayoutOps( } else { mlir::emitError(merged_layout.first.getLoc()) << "value type is not TensorType as expected."; + return mlir::failure(); + } + } + return mlir::success(); +} + +mlir::LogicalResult UpdateDTensorSendRecvOps(mlir::ModuleOp module, + mlir::OpBuilder& builder) { + llvm::SmallVector send_ops; + llvm::SmallVector recv_ops; + module.walk([&](mlir::Operation* op) { + if (auto send_op = llvm::dyn_cast(op)) + send_ops.emplace_back(send_op); + if (auto recv_op = llvm::dyn_cast(op)) + recv_ops.emplace_back(recv_op); + }); + + for (auto send_op : send_ops) { + absl::StatusOr layout = + ExtractRequiredLayoutFromOperand(send_op->getOperand(0)); + if (!layout.ok()) { + send_op->emitOpError() + << "Cannot add source layout to DTensorRecv for DTensorSend: " + << layout.status().message(); + return mlir::failure(); } + absl::StatusOr recv_op = + GetCorrespondingDTensorSendRecvOp(module, + send_op); + if (!recv_op.ok()) { + send_op->emitOpError() + << "Cannot add source layout to DTensorRecv for DTensorSend: " + << recv_op.status().message(); + return mlir::failure(); + } + auto layout_attr = + mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout); + recv_op.value()->setAttr(kSourceLayoutAttr, layout_attr); } + for (auto recv_op : recv_ops) { + absl::StatusOr layout = ExtractRequiredSingleLayoutFromOp(recv_op); + if (!layout.ok()) { + recv_op->emitOpError() + << "Cannot add source layout to DTensorRecv for DTensorSend: " + << layout.status().message(); + return mlir::failure(); + } + absl::StatusOr send_op = + GetCorrespondingDTensorSendRecvOp(module, + recv_op); + if (!send_op.ok()) { + recv_op->emitOpError() + << "Cannot add target layout to DTensorSend for DTensorRecv: " + << send_op.status().message(); + return mlir::failure(); + } + auto layout_attr = + mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout); + send_op.value()->setAttr(kTargetLayoutAttr, layout_attr); + } return mlir::success(); } @@ -1565,6 +1624,9 @@ struct DLayoutPropagationPassV2 if (mlir::failed( InsertDTensorLayoutForIfRegionOp(if_ops, builder.getContext()))) return signalPassFailure(); + + if (mlir::failed(UpdateDTensorSendRecvOps(module, 
builder))) + return signalPassFailure(); }; }; diff --git a/tensorflow/dtensor/mlir/merge_clusters.cc b/tensorflow/dtensor/mlir/merge_clusters.cc index 37ab79d1325bdf..3aa323ec1d4cc1 100644 --- a/tensorflow/dtensor/mlir/merge_clusters.cc +++ b/tensorflow/dtensor/mlir/merge_clusters.cc @@ -283,11 +283,10 @@ void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh, absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs); *num_send_recvs += 1; - const Layout target_layout = Layout::ReplicatedOnMesh(mesh, 0); builder.create( if_region.getLoc(), if_region.getCond(), builder.getStringAttr(send_recv_key), - mlir::dtensor::LayoutAttr::get(context, target_layout)); + mlir::dtensor::MeshAttr::get(context, mesh)); // Create new cluster op that contains cloned if operation. auto new_cluster = builder.create( @@ -303,7 +302,7 @@ void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh, if_region.getLoc(), predicate_tensor_type, builder.getStringAttr(send_recv_key), mlir::TF::ShapeAttr::get(context, predicate_tensor_type), - mlir::dtensor::LayoutAttr::get(context, target_layout)); + mlir::dtensor::MeshAttr::get(context, mesh)); // Clone tf.IfRegion op inside newly created cluster and make sure // that the predicate tensor is from DTensorRecv op created above. diff --git a/tensorflow/dtensor/mlir/mesh_propagation.cc b/tensorflow/dtensor/mlir/mesh_propagation.cc index 54d547cc82683e..2bfdd39a543914 100644 --- a/tensorflow/dtensor/mlir/mesh_propagation.cc +++ b/tensorflow/dtensor/mlir/mesh_propagation.cc @@ -195,10 +195,9 @@ mlir::LogicalResult InferMeshFromInputs( // inputs. `tf.CopyToMesh` specifies that all operations following the // operation is executed on target device mesh cluster specified by // `tf.CopyToMesh`. - if (llvm::isa(&cluster.GetBody().front())) { - return result; - } - if (llvm::isa(&cluster.GetBody().front())) { + if (llvm::isa( + &cluster.GetBody().front())) { return result; } @@ -288,8 +287,11 @@ mlir::LogicalResult InferMeshFromConsumers( // `tf.CopyToMesh`. Therefore, if `consumer` operation is `tf.CopyToMesh` // do not propagate mesh backwards to `cluster`. if (llvm::isa(consumer)) continue; + if (llvm::isa(consumer)) continue; if (llvm::isa(&cluster.GetBody().front())) continue; + if (llvm::isa(&cluster.GetBody().front())) + continue; Mesh extracted_mesh; @@ -672,23 +674,20 @@ mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromConsumers( return mlir::success(); } -mlir::LogicalResult RewriteCopyToMeshGradOp( +mlir::LogicalResult PropagateLikeMesh( const llvm::DenseMap>& producers, mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder, bool* mesh_changed) { - auto backward_op = llvm::dyn_cast_or_null( - &cluster.GetBody().front()); - if (!backward_op) { + mlir::Operation* backward_op = &cluster.GetBody().front(); + + if (!mlir::isa(backward_op) && + !mlir::isa(backward_op)) { // No CopyToMeshGradOp is found. Either the cluster did not have one, // or it has been rewritten from previous iterations. return mlir::success(); } - if (cluster->getAttrOfType(kMeshAttr)) { - return backward_op.emitOpError( - "A cluster with CopyToMeshGrad is already assigned a mesh. 
" - "This indicates an internal error."); - } + auto old_mesh = cluster->getAttrOfType(kMeshAttr); std::optional mesh; mlir::OpOperand& operand = backward_op->getOpOperand(1); // forward_input(); @@ -697,27 +696,14 @@ mlir::LogicalResult RewriteCopyToMeshGradOp( if (mlir::failed(ExtractMeshFromOperand(producers, &operand, &mesh))) { return mlir::success(); } - cluster->setAttr(kMeshAttr, builder->getStringAttr(mesh->ToString())); - - // Rewrites to CopyToMesh, by combining the sharding spec of the reference - // layout with the mesh. - // This assumes the CopyToMesh maintains the layout of the input and only - // changes the mesh. - builder->setInsertionPoint(backward_op); - StatusOr layout = - Layout::FromString(backward_op.getReferenceLayout().str()); - if (!layout.ok()) { - return backward_op.emitOpError("Failure passing layout: ") - << backward_op.getReferenceLayout().str(); + if (old_mesh != nullptr) { + if (old_mesh.getValue().str() == mesh->ToString()) { + return mlir::success(); + } } - layout->set_mesh(mesh.value()); - auto op = builder->create( - backward_op->getLoc(), backward_op->getResult(0).getType(), - backward_op.getInput(), layout->ToString()); + cluster->setAttr(kMeshAttr, builder->getStringAttr(mesh->ToString())); - backward_op->replaceAllUsesWith(op); - backward_op->erase(); *mesh_changed = true; return mlir::success(); } @@ -760,8 +746,8 @@ mlir::LogicalResult DTensorMeshPropagation::PropagateMesh( return mlir::failure(); } for (auto cluster : llvm::reverse(cluster_ops)) { - if (mlir::failed(RewriteCopyToMeshGradOp(producers, cluster, builder, - mesh_changed))) { + if (mlir::failed( + PropagateLikeMesh(producers, cluster, builder, mesh_changed))) { return mlir::failure(); } } diff --git a/tensorflow/dtensor/mlir/op_to_device_cluster.cc b/tensorflow/dtensor/mlir/op_to_device_cluster.cc index 02c7fc73dc83bc..87530f1faed7fa 100644 --- a/tensorflow/dtensor/mlir/op_to_device_cluster.cc +++ b/tensorflow/dtensor/mlir/op_to_device_cluster.cc @@ -58,16 +58,15 @@ mlir::LogicalResult WrapDeviceCluster(mlir::OpBuilder *builder, if (auto layout_op = llvm::dyn_cast(op)) { cluster->setAttr(kMeshAttr, builder->getStringAttr( layout_op.getLayout().mesh().ToString())); - } else if (auto copy_to_mesh = llvm::dyn_cast(op)) { + } else if (auto copy_to_mesh = llvm::dyn_cast(op)) { const std::string layout_string = copy_to_mesh.getLayout().str(); - auto layout_or = Layout::FromString(layout_string); - if (!layout_or.ok()) - return op->emitOpError( - llvm::formatv("Found tf.CopyToMesh Op with unparsable layout : {0}", - layout_string)); + auto layout = Layout::FromString(layout_string); + if (!layout.ok()) + return op->emitOpError(llvm::formatv( + "Found tf.Relayout Op with unparsable layout: {0}", layout_string)); cluster->setAttr(kMeshAttr, - builder->getStringAttr(layout_or->mesh().ToString())); + builder->getStringAttr(layout->mesh().ToString())); } else { // If mesh configuration can be inferred from the op directly, use the mesh // information from op attribute directly. 
If op is not annotated with mesh diff --git a/tensorflow/dtensor/mlir/restore_shape_inference.cc b/tensorflow/dtensor/mlir/restore_shape_inference.cc index 1bce632d6ebea5..abbda8cbdd80ae 100644 --- a/tensorflow/dtensor/mlir/restore_shape_inference.cc +++ b/tensorflow/dtensor/mlir/restore_shape_inference.cc @@ -43,10 +43,12 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, mlir::Type type) { mlir::Operation* op = value.getDefiningOp(); if (op == nullptr) return mlir::success(); - if (!llvm::isa(op)) { return op->emitOpError( - llvm::formatv("Expected an Identity, Cast, DTensorRecv, or RestoreV2 " + llvm::formatv("Expected an Identity, Relayout, Cast, DTensorLayout, " + "DTensorRecv, or RestoreV2 " "op, but got: {0}. Please file a bug to the DTensor team." "(component id: 833864)", op->getName().getStringRef())); @@ -98,6 +100,18 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, // Recursively shape inference to the input of the identity op. return BackwardShapeInferenceToRestoreOp(module, builder, new_identity_op.getInput(), type); + } else if (auto relayout_op = + llvm::dyn_cast_or_null(op)) { + relayout_op->getResult(0).setType(type); + // Recursively shape inference to the input of the identity op. + return BackwardShapeInferenceToRestoreOp(module, builder, + relayout_op->getOperand(0), type); + } else if (auto layout_op = + llvm::dyn_cast_or_null(op)) { + layout_op->getResult(0).setType(type); + // Recursively shape inference to the input of the identity op. + return BackwardShapeInferenceToRestoreOp(module, builder, + layout_op->getOperand(0), type); } else if (auto recv_op = llvm::dyn_cast_or_null(op)) { // If we have a DTensorRecv, then there is cross mesh action and the // RestoreV2Op we want to fix is on the mesh of the corresponding @@ -107,8 +121,7 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, recv_op.getLoc(), type, builder->getStringAttr(recv_op.getKey()), mlir::TF::ShapeAttr::get(builder->getContext(), type.dyn_cast()), - mlir::dtensor::LayoutAttr::get(builder->getContext(), - recv_op.getLayout())); + mlir::dtensor::MeshAttr::get(builder->getContext(), recv_op.getMesh())); recv_op.replaceAllUsesWith(new_recv_op.getOutput()); recv_op.erase(); @@ -137,7 +150,7 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, // leading up to the tf.RestoreV2 op. mlir::LogicalResult PropagateShapeInformationFromAssignVariableOp( mlir::ModuleOp module) { - module.walk([&](mlir::TF::AssignVariableOp assign_op) { + auto result = module.walk([&](mlir::TF::AssignVariableOp assign_op) { // Check that the `value` has an unknown shape. if (ValueRank(assign_op.getValue()) == -1) { StatusOr> shape = @@ -148,7 +161,7 @@ mlir::LogicalResult PropagateShapeInformationFromAssignVariableOp( "missing it during CheckpointShapeInference."); return mlir::WalkResult::interrupt(); } - // Propagete shape backwards to all the ops that use or produce + // Propagate shape backwards to all the ops that use or produce // the value with missing shape. 
mlir::OpBuilder builder(assign_op); mlir::Type known_type = GetSubtypeOrSelf(assign_op.getResource()); @@ -163,6 +176,7 @@ mlir::LogicalResult PropagateShapeInformationFromAssignVariableOp( return mlir::WalkResult::advance(); }); + if (result.wasInterrupted()) return mlir::failure(); return mlir::success(); } diff --git a/tensorflow/dtensor/mlir/tests/handle_cross_cluster_dependencies.mlir b/tensorflow/dtensor/mlir/tests/handle_cross_cluster_dependencies.mlir index 2bb3b0cb3931ff..83b4c6aa041949 100644 --- a/tensorflow/dtensor/mlir/tests/handle_cross_cluster_dependencies.mlir +++ b/tensorflow/dtensor/mlir/tests/handle_cross_cluster_dependencies.mlir @@ -29,7 +29,7 @@ func.func @main() -> tensor { %2 = "tf_device.cluster"() ({ - %3 = "tf.CopyToMesh"(%0#0) { layout ="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) + %3 = "tf.Relayout"(%0#0) { layout="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) %4 = "tf.Neg"(%3) : (tensor) -> tensor tf_device.return %4 : tensor }) {_mesh="TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> (tensor) @@ -54,13 +54,13 @@ func.func @main() -> tensor { // CHECK: "tf_device.cluster" // CHECK-NEXT: %[[CONST_OUT:.*]] = "tf.Const"() - // CHECK-NEXT: %[[LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[CONST_OUT]]) - // CHECK-SAME: layout = #dtensor.layout + // CHECK-NEXT: %[[LAYOUT_OUT:.*]] = "tf.Relayout"(%[[CONST_OUT]]) + // CHECK-SAME: layout = "sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[LAYOUT_OUT]] // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> () %2 = "tf_device.cluster"() ({ - %3 = "tf.CopyToMesh"(%0#0) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) + %3 = "tf.Relayout"(%0#0) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) %4 = "tf.Neg"(%3) : (tensor) -> tensor tf_device.return %4 : tensor }) {_mesh="TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> (tensor) @@ -76,8 +76,8 @@ func.func @main() -> tensor { // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"() // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[A_OUT]] // CHECK-NEXT: "tf.DTensorSend"(%[[A_OUT]] - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" - // CHECK-SAME: target_layout = #dtensor.layout + // CHECK-SAME: key = 
"communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: target_mesh = #dtensor.mesh // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> tensor %0:2 = "tf_device.cluster"() ({ @@ -88,13 +88,14 @@ func.func @main() -> tensor { // CHECK: "tf_device.cluster" // CHECK-NEXT: %[[RECV_OUT:.*]] = "tf.DTensorRecv"() - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" - // CHECK-SAME: layout = #dtensor.layout - // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[RECV_OUT]] + // CHECK-SAME: key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: mesh = #dtensor.mesh + // CHECK-NEXT: %[[RECVRELAYOUT_OUT:.*]] = "tf.Relayout"(%[[RECV_OUT]] + // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[RECVRELAYOUT_OUT]] // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> () %2 = "tf_device.cluster"() ({ - %3 = "tf.CopyToMesh"(%0#0) { layout ="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) + %3 = "tf.Relayout"(%0#0) { layout ="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) %4 = "tf.Neg"(%3) : (tensor) -> tensor tf_device.return %4 : tensor }) {_mesh="TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> (tensor) @@ -147,11 +148,11 @@ func.func @main() { // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"() // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[A_OUT]] // CHECK-NEXT: "tf.DTensorSend"(%[[NEG_OUT]] - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" - // CHECK-SAME: target_layout = #dtensor.layout + // CHECK-SAME: key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: target_mesh = #dtensor.mesh // CHECK-NEXT: "tf.DTensorSend"(%[[NEG_OUT]] - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_1" - // CHECK-SAME: target_layout = #dtensor.layout + // CHECK-SAME: key = "communication_key_GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_1" + // CHECK-SAME: target_mesh = #dtensor.mesh // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> tensor %0 = "tf_device.cluster"() ({ @@ -162,26 +163,28 @@ func.func @main() { // CHECK: 
"tf_device.cluster" // CHECK-NEXT: %[[RECV_OUT_1:.*]] = "tf.DTensorRecv"() - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" - // CHECK-SAME: layout = #dtensor.layout - // CHECK-NEXT: %[[NEG_OUT_1:.*]] = "tf.Neg"(%[[RECV_OUT_1]] + // CHECK-SAME: key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: mesh = #dtensor.mesh + // CHECK-NEXT: %[[RELAYOUT_OUT_1:.*]] = "tf.Relayout"(%[[RECV_OUT_1]] + // CHECK-NEXT: %[[NEG_OUT_1:.*]] = "tf.Neg"(%[[RELAYOUT_OUT_1]] // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> () %2 = "tf_device.cluster"() ({ - %3 = "tf.CopyToMesh"(%0) { layout ="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) + %3 = "tf.Relayout"(%0) { layout ="sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor) -> (tensor) %4 = "tf.Neg"(%3) : (tensor) -> tensor tf_device.return %4 : tensor }) {_mesh="TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> (tensor) // CHECK: "tf_device.cluster" // CHECK-NEXT: %[[RECV_OUT_2:.*]] = "tf.DTensorRecv"() - // CHECK-SAME: key = "communication_key_sharding_specs:unsharded, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_1" - // CHECK-SAME: layout = #dtensor.layout - // CHECK-NEXT: %[[NEG_OUT_2:.*]] = "tf.Neg"(%[[RECV_OUT_2]] + // CHECK-SAME: key = "communication_key_GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_1" + // CHECK-SAME: mesh = #dtensor.mesh + // CHECK-NEXT: %[[RELAYOUT_OUT_2:.*]] = "tf.Relayout"(%[[RECV_OUT_2]] + // CHECK-NEXT: %[[NEG_OUT_2:.*]] = "tf.Neg"(%[[RELAYOUT_OUT_2]] // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> () %3 = "tf_device.cluster"() ({ - %4 = "tf.CopyToMesh"(%0) { layout ="sharding_specs:unsharded, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3"} : (tensor) -> (tensor) + %4 = "tf.Relayout"(%0) { layout ="sharding_specs:unsharded, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3"} : (tensor) -> (tensor) %5 = "tf.Neg"(%4) : (tensor) -> tensor tf_device.return %4 : tensor }) {_mesh="GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3"} : () -> (tensor) @@ -216,16 +219,16 @@ func.func @main() -> tensor { // CHECK: "tf_device.cluster" // CHECK-NEXT: %[[CONST_OUT:.*]] = "tf.Const"() - // CHECK-NEXT: %[[LAYOUT_OUT:.*]] = 
"tf.DTensorLayout"(%[[CONST_OUT]]) - // CHECK-SAME: layout = #dtensor.layout + // CHECK-NEXT: %[[LAYOUT_OUT:.*]] = "tf.Relayout"(%[[CONST_OUT]]) + // CHECK-SAME: layout = "sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" // CHECK-NEXT: %[[RELAYOUT_OUT:.*]] = "tf.Relayout"(%[[LAYOUT_OUT]]) // CHECK-SAME: layout = "sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" // CHECK-NEXT: %[[NEG_OUT:.*]] = "tf.Neg"(%[[RELAYOUT_OUT]] // CHECK-NEXT: tf_device.return // CHECK-NEXT: () -> () %2 = "tf_device.cluster"() ({ - %3 = "tf.CopyToMesh"(%0#0) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" } : (tensor) -> (tensor) - %4 = "tf.CopyToMesh"(%3) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" } : (tensor) -> (tensor) + %3 = "tf.Relayout"(%0#0) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" } : (tensor) -> (tensor) + %4 = "tf.Relayout"(%3) { layout ="sharding_specs:scalar, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3" } : (tensor) -> (tensor) %5 = "tf.Neg"(%4) : (tensor) -> tensor tf_device.return %5 : tensor }) {_mesh="TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> (tensor) diff --git a/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir b/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir index 20c8c9a7b4c79b..2e9b079319b0e0 100644 --- a/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir +++ b/tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir @@ -807,9 +807,10 @@ func.func @main(%arg0: tensor, tf._mesh = "|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"}) { // CHECK: "tf_device.cluster" "tf_device.cluster"() ({ - %1 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", layout = #dtensor.layout, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32> - %2 = "tf.Identity"(%1) : (tensor<4x8xf32>) -> tensor<4x8xf32> - "tf.AssignVariableOp"(%arg4, %2) {validate_shape = true} : (tensor<*x!tf_type.resource>>, tensor<4x8xf32>) -> () + %1 = "tf.DTensorRecv"() {key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1>, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32> + %2 = "tf.Relayout"(%1) {global_shape = #tf_type.shape<8x8>, layout = "sharding_specs:unsharded,unsharded, 
mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"} : (tensor<4x8xf32>) -> tensor<4x8xf32> + %3 = "tf.Identity"(%2) : (tensor<4x8xf32>) -> tensor<4x8xf32> + "tf.AssignVariableOp"(%arg4, %3) {validate_shape = true} : (tensor<*x!tf_type.resource>>, tensor<4x8xf32>) -> () tf_device.return }) {_mesh="TPU|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1"} : () -> (tensor, tensor) @@ -819,7 +820,7 @@ func.func @main(%arg0: tensor, // CHECK-SAME: layout = #dtensor.layout "tf_device.cluster"() ({ %6 = "tf.RestoreV2"(%arg1, %arg2, %arg3) {} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> (tensor<4x8xf32>) - "tf.DTensorSend"(%6) {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", target_layout = #dtensor.layout} : (tensor<4x8xf32>) -> () + "tf.DTensorSend"(%6) {key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", target_mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1>} : (tensor<4x8xf32>) -> () tf_device.return }) {_mesh="CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1"} : () -> (tensor) func.return diff --git a/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir b/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir index 4cd10793321957..da12e597ac8c70 100644 --- a/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir +++ b/tensorflow/dtensor/mlir/tests/lower_send_recv.mlir @@ -33,24 +33,24 @@ func.func @main(%arg0: tensor) { // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: %[[TPU_RECV_OUT:.*]] = "tf.XlaRecvFromHost"() - // CHECK-SAME: key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" - // CHECK-NEXT: %[[TPU_LAYOUT_OUT:.*]] = "tf.DTensorLayout"(%[[TPU_RECV_OUT]]) + // CHECK-SAME: key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-NEXT: %[[TPU_LAYOUT_OUT:.*]] = "tf.Relayout"(%[[TPU_RECV_OUT]]) // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" // CHECK-NEXT: "tf.XlaSendToHost"(%[[A_OUT]]) "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", target_mesh = #dtensor.mesh} : (tensor<1xi32>) -> () - %2 = "tf.DTensorRecv"() {key = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_2", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> (tensor<1xi32>) + %2 = 
"tf.DTensorRecv"() {key = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_2", mesh = #dtensor.mesh, shape = #tf_type.shape<>} : () -> (tensor<1xi32>) "tf.B"(%2) : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<1xi32> - %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>} : () -> tensor<1xi32> + %1 = "tf.Relayout"(%0) {global_shape = #tf_type.shape<1>, layout = "sharding_specs:unsharded, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : (tensor<1xi32>) -> tensor<1xi32> %2 = "tf.A"(%1) : (tensor<1xi32>) -> tensor<1xi32> - "tf.DTensorSend"(%2) {key = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_2", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%2) {key = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_2", target_mesh = #dtensor.mesh} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> () func.return diff --git a/tensorflow/dtensor/mlir/tests/mesh_propagation.mlir b/tensorflow/dtensor/mlir/tests/mesh_propagation.mlir index f75b5ffef77f5b..ec9084550a1c35 100644 --- a/tensorflow/dtensor/mlir/tests/mesh_propagation.mlir +++ b/tensorflow/dtensor/mlir/tests/mesh_propagation.mlir @@ -341,20 +341,20 @@ module @test_multi_mesh { }) : () -> tensor<4xi32> // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.CopyToMesh" + // CHECK-NEXT: "tf.Relayout" // CHECK-NEXT: tf_device.return // CHECK-NEXT: _mesh = "TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1" %6 = "tf_device.cluster"() ({ - %7 = "tf.CopyToMesh"(%2) { layout = "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : (tensor<4xi32>) -> tensor<4xi32> + %7 = "tf.Relayout"(%2) { layout = "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : (tensor<4xi32>) -> tensor<4xi32> tf_device.return %7 : tensor<4xi32> }) { _mesh = "TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1" } : () -> tensor<4xi32> // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.CopyToMesh" + // CHECK-NEXT: "tf.Relayout" // CHECK-NEXT: tf_device.return // CHECK-NEXT: _mesh = "TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1" %8 = "tf_device.cluster"() ({ - %9 = "tf.CopyToMesh"(%4) { layout = "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : (tensor<4xi32>) -> tensor<4xi32> + %9 = "tf.Relayout"(%4) { layout 
= "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : (tensor<4xi32>) -> tensor<4xi32> tf_device.return %9 : tensor<4xi32> }) { _mesh = "TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1" } : () -> tensor<4xi32> @@ -373,27 +373,6 @@ module @test_multi_mesh { // ----- -// Checks CopyToMeshGrad is written to CopyToMesh. -// CHECK-LABEL: module @test_copy_to_mesh_grad -module @test_copy_to_mesh_grad { - func.func @main(%arg0: tensor<4xi32> {tf._layout = "sharding_specs:not_sharded mesh:CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1"}, - %arg1: tensor<4xi32> {tf._layout = "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"}) -> (tensor<4xi32>) { - - // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.CopyToMesh" - // CHECK-NEXT: tf_device.return - // CHECK-NEXT: _mesh = "TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1" - %0 = "tf_device.cluster"() ({ - %1 = "tf.CopyToMeshGrad"(%arg0, %arg1) { reference_layout = "sharding_specs:not_sharded mesh:TPU|x=2|0,1|0,1|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - tf_device.return %1 : tensor<4xi32> - }) : () -> tensor<4xi32> - - func.return %0 :tensor<4xi32> - } -} - -// ----- - // Check mesh propagation of ops inside tf.WhileRegion op. // CHECK-LABEL: module @test_while module @test_while { diff --git a/tensorflow/dtensor/mlir/tests/op_to_device_cluster.mlir b/tensorflow/dtensor/mlir/tests/op_to_device_cluster.mlir index fe73bfbd745e83..c736fdb55f6732 100644 --- a/tensorflow/dtensor/mlir/tests/op_to_device_cluster.mlir +++ b/tensorflow/dtensor/mlir/tests/op_to_device_cluster.mlir @@ -39,10 +39,10 @@ func.func @check_device_cluster_from_dtensor_layout_op(%arg0: tensor) -> te // CHECK-LABEL: func @check_device_cluster_from_copy_to_mesh_op func.func @check_device_cluster_from_copy_to_mesh_op(%arg0: tensor) -> tensor { // CHECK: "tf_device.cluster" - // CHECK-NEXT: %[[A_OUT:.*]] = "tf.CopyToMesh" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.Relayout" // CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: _mesh = "|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3" - %0 = "tf.CopyToMesh"(%arg0) { layout = "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"} : (tensor) -> tensor + %0 = "tf.Relayout"(%arg0) { layout = "sharding_specs:x, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"} : (tensor) -> tensor func.return %0 : tensor } diff --git a/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir b/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir index 69a0d3f875d371..ee6ca5be3bf669 100644 --- a/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir +++ b/tensorflow/dtensor/mlir/tests/restore_and_assign.mlir @@ -75,7 +75,12 @@ func.func @main( %arg4: tensor<*x!tf_type.resource>>) { // CHECK: "tf_device.cluster" // CHECK-NEXT: %[[RESOURCE:.*]] = "tf.DTensorLayout"(%arg4) - // CHECK-NEXT: %[[RECV:.*]] = "tf.DTensorRecv"() {key = 
"communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", layout = #dtensor.layout, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32> + // CHECK-NEXT: %[[RECV:.*]] = "tf.DTensorRecv"() { + // CHECK-SAME: key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1" + // CHECK-SAME: mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1> + // CHECK-SAME: shape = #tf_type.shape<4x8> + // CHECK-SAME: source_layout = #dtensor.layout + // CHECK-SAME: () -> tensor<4x8xf32> // CHECK-NEXT: %[[RECV_DL:.*]] = "tf.DTensorLayout"(%[[RECV]]) // CHECK-NEXT: %[[IDENTITY:.*]] = "tf.Identity"(%[[RECV_DL]]) : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: %[[IDENTITY_DL:.*]] = "tf.DTensorLayout"(%[[IDENTITY]]) @@ -83,7 +88,7 @@ func.func @main( // CHECK-NEXT: tf_device.return "tf_device.cluster"() ({ %4 = "tf.DTensorLayout"(%arg4) {global_shape = #tf_type.shape<4x8>, layout = #dtensor.layout} : (tensor<*x!tf_type.resource>>) -> tensor<*x!tf_type.resource>> - %5 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", layout = #dtensor.layout, shape = #tf_type.shape<*>} : () -> tensor<*xf32> + %5 = "tf.DTensorRecv"() {key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1>, shape = #tf_type.shape<*>} : () -> tensor<*xf32> %6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32> "tf.AssignVariableOp"(%4, %6) {validate_shape = true} : (tensor<*x!tf_type.resource>>, tensor<*xf32>) -> () tf_device.return @@ -102,7 +107,7 @@ func.func @main( %1 = "tf.DTensorLayout"(%arg2) {global_shape = #tf_type.shape<>, layout = #dtensor.layout} : (tensor) -> tensor %2 = "tf.DTensorLayout"(%arg3) {global_shape = #tf_type.shape<>, layout = #dtensor.layout} : (tensor) -> tensor %3 = "tf.RestoreV2"(%0, %1, %2) {} : (tensor, tensor, tensor) -> (tensor<*xf32>) - "tf.DTensorSend"(%3) {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", target_layout = #dtensor.layout} : (tensor<*xf32>) -> () + "tf.DTensorSend"(%3) {key = "communication_key_|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1", target_mesh = #dtensor.mesh<|x=2|0,1|0,1|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1>} : (tensor<*xf32>) -> () tf_device.return }) {_mesh="CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1"} : () -> (tensor) func.return diff --git a/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir b/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir index ba2a3a3f9071ba..6f78b552e47f40 100644 --- a/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir +++ b/tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir @@ -19,12 +19,12 @@ func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tenso // Check the tf.RestoreV2Op's and all connected ops' resulting types are inferred from the AssignVariableOps in cross mesh 
cluster. All unknown shapes should be known after this pass. func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x!tf_type.string>, %arg3: tensor<2x!tf_type.string>, %arg4: tensor<*x!tf_type.resource>>, %arg5: tensor<*x!tf_type.resource>>) { // CHECK: "tf_device.cluster" - // CHECK-NEXT: %2 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", layout = #dtensor.layout, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32> + // CHECK-NEXT: %2 = "tf.DTensorRecv"() {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>, shape = #tf_type.shape<4x8>} : () -> tensor<4x8xf32> // CHECK-NEXT: %3 = "tf.Identity"(%2) : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: "tf.AssignVariableOp"(%arg4, %3) {validate_shape = true} : (tensor<*x!tf_type.resource>>, tensor<4x8xf32>) -> () // CHECK-NEXT: tf_device.return "tf_device.cluster"() ({ - %1 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", layout = #dtensor.layout, shape = #tf_type.shape<*>} : () -> tensor<*xf32> + %1 = "tf.DTensorRecv"() {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>, shape = #tf_type.shape<*>} : () -> tensor<*xf32> %2 = "tf.Identity"(%1) : (tensor<*xf32>) -> tensor<*xf32> "tf.AssignVariableOp"(%arg4, %2) {validate_shape = true} : (tensor<*x!tf_type.resource>>, tensor<*xf32>) -> () tf_device.return @@ -34,13 +34,13 @@ func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tenso // CHECK-NEXT: %2:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<4x8xf32>, tensor) // CHECK-NEXT: %3 = "tf.Identity"(%2#1) : (tensor) -> tensor // CHECK-NEXT: "tf.AssignVariableOp"(%arg5, %3) {validate_shape = false} : (tensor<*x!tf_type.resource>>, tensor) -> () - // CHECK-NEXT: "tf.DTensorSend"(%2#0) {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_layout = #dtensor.layout} : (tensor<4x8xf32>) -> () + // CHECK-NEXT: "tf.DTensorSend"(%2#0) {key = 
"communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>} : (tensor<4x8xf32>) -> () // CHECK-NEXT: tf_device.return "tf_device.cluster"() ({ %6:2 = "tf.RestoreV2"(%arg1, %arg2, %arg3) {} : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<*xf32>, tensor<*xi64>) %7 = "tf.Identity"(%6#1) : (tensor<*xi64>) -> tensor<*xi64> "tf.AssignVariableOp"(%arg5, %7) {validate_shape = false} : (tensor<*x!tf_type.resource>>, tensor<*xi64>) -> () - "tf.DTensorSend"(%6#0) {key = "communication_key_sharding_specs:unsharded,unsharded, mesh:|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_layout = #dtensor.layout} : (tensor<*xf32>) -> () + "tf.DTensorSend"(%6#0) {key = "communication_key_|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3", target_mesh = #dtensor.mesh<|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/replica:0/task:0/device:TPU:0,/job:localhost/replica:0/task:0/device:TPU:1,/job:localhost/replica:0/task:0/device:TPU:2,/job:localhost/replica:0/task:0/device:TPU:3>} : (tensor<*xf32>) -> () tf_device.return }) {_mesh="CPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3"} : () -> (tensor) func.return diff --git a/tensorflow/dtensor/mlir/tests/spmd_dtensor_ops.mlir b/tensorflow/dtensor/mlir/tests/spmd_dtensor_ops.mlir index 5d088f74facf03..0aa9500915fad7 100644 --- a/tensorflow/dtensor/mlir/tests/spmd_dtensor_ops.mlir +++ b/tensorflow/dtensor/mlir/tests/spmd_dtensor_ops.mlir @@ -26,7 +26,7 @@ func.func @main(%arg0: tensor) { // CHECK-NEXT: %[[ZERO_2:.*]] = "tf.Const" // CHECK-SAME: value = dense<0> // CHECK-NEXT: "tf._XlaSendFromHostV2"(%[[CONST_OUT]], %[[PROGRAM_KEY]], %[[ZERO_2]]) - // CHECK-SAME: key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" // CHECK-NEXT: "tf.Yield" // CHECK: "tf.Yield" // CHECK: "tf_device.cluster" @@ -47,7 +47,7 @@ func.func @main(%arg0: tensor) { // CHECK-NEXT: %[[PREDICATE_2:.*]] = "tf.Equal"(%[[DEVICE_ORDINAL_SCALAR_64_2]], %[[ZERO_2]]) // CHECK-NEXT: %[[IF_OUT:.*]] = "tf.IfRegion"(%[[PREDICATE_2]]) // CHECK-NEXT: %[[RECV_OUT:.*]] = "tf.XlaRecvFromHost"() - // CHECK-SAME: key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" + // CHECK-SAME: key = 
"communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0" // CHECK-NEXT: "tf.Yield"(%[[RECV_OUT]]) // CHECK: %[[ZEROS_3:.*]] = "tf.Const" // CHECK-NEXT: "tf.Yield"(%[[ZEROS_3]]) @@ -56,11 +56,11 @@ func.func @main(%arg0: tensor) { "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>, source_layout = #dtensor.layout} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> tf_device.return }) {_mesh = "TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> () @@ -93,7 +93,7 @@ func.func @main(%arg0: tensor) { // CHECK-NEXT: %[[PREDICATE:.*]] = "tf.Equal"(%[[DEVICE_ORDINAL_SCALAR_64]], %[[ZERO]]) // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE]]) // CHECK-NEXT: "tf.XlaSendToHost"(%[[CONST_OUT]]) - // CHECK-SAME: key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0" + // CHECK-SAME: key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0" // CHECK-NEXT: "tf.Yield" // CHECK: "tf.Yield" // CHECK: "tf_device.cluster" @@ -114,12 +114,12 @@ func.func @main(%arg0: tensor) { "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = 
"TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>, source_layout = #dtensor.layout} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> tf_device.return @@ -154,7 +154,7 @@ func.func @main(%arg0: tensor) { // CHECK-NEXT: %[[PREDICATE:.*]] = "tf.Equal"(%[[DEVICE_ORDINAL_SCALAR_64]], %[[ZERO]]) // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE]]) // CHECK-NEXT: "tf.XlaSendToHost"(%[[ALL_GATHER_OUT]]) - // CHECK-SAME: key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0" + // CHECK-SAME: key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0" // CHECK-NEXT: "tf.Yield" // CHECK: "tf.Yield" // CHECK: "tf_device.cluster" @@ -175,12 +175,12 @@ func.func @main(%arg0: tensor) { "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<10> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", target_layout = #dtensor.layout} : (tensor<2xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<2xi32>) -> () tf_device.return }) {_mesh = "TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<2xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>, source_layout = #dtensor.layout} : () -> tensor<2xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> tf_device.return @@ -197,12 +197,12 @@ func.func @main(%arg0: tensor) { %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> // expected-error @+1 {{Only use CopyToMesh to transfer data across different mesh cluster}} - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "CPU|x=2|*CPU"} : () -> () 
"tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_CPU|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>, source_layout = #dtensor.layout} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> tf_device.return @@ -219,12 +219,12 @@ func.func @main(%arg0: tensor) { %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> // expected-error @+1 {{f.CopyToMesh op must be used to send data from/to host mesh}} - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:GPU|x=2|0,1|0,1|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_GPU|x=2|0,1|0,1|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "TPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:TPU:0,/job:localhost/task:0/device:TPU:1,/job:localhost/task:0/device:TPU:2,/job:localhost/task:0/device:TPU:3"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:GPU|x=2|0,1|0,1|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1_0", layout = #dtensor.layout, shape = #tf_type.shape<>} : () -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_GPU|x=2|0,1|0,1|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1_0", mesh = #dtensor.mesh, shape = #tf_type.shape<>, source_layout = #dtensor.layout} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> tf_device.return @@ -242,16 +242,16 @@ func.func @main(%arg0: tensor) { // CHECK-NEXT: "tf._HostSend"(%[[CONST_OUT]]) // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf._HostRecv" - // CHECK-SAME: tensor_name = "communication_key_sharding_specs:unsharded, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0" + // CHECK-SAME: tensor_name = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0" "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:unsharded, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0", target_layout = #dtensor.layout} : (tensor<1xi32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<1xi32>) -> () tf_device.return }) {_mesh = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:unsharded, mesh:CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0", layout = 
#dtensor.layout, shape = #tf_type.shape<1>} : () -> tensor<1xi32> + %0 = "tf.DTensorRecv"() {key = "communication_key_CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1_0", mesh = #dtensor.mesh, shape = #tf_type.shape<1>, source_layout = #dtensor.layout} : () -> tensor<1xi32> %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<1>, layout = #dtensor.layout} : (tensor<1xi32>) -> tensor<1xi32> tf_device.return }) {_mesh = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:1"} : () -> () @@ -311,11 +311,11 @@ func.func @main(%arg0: tensor) { "tf_device.cluster"() ({ %0 = "tf.Const"() {value = dense<1.> : tensor<8x8xf32>} : () -> tensor<8x8xf32> %1 = "tf.DTensorLayout"(%0) {_global_shape = [#tf_type.shape<8x8>], global_shape = #tf_type.shape<8x8>, layout = #dtensor.layout} : (tensor<8x8xf32>) -> tensor<8x8xf32> - "tf.DTensorSend"(%1) {key = "communication_key_sharding_specs:, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_0", target_layout = #dtensor.layout} : (tensor<8x8xf32>) -> () + "tf.DTensorSend"(%1) {key = "communication_key_GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_0", target_mesh = #dtensor.mesh, target_layout = #dtensor.layout} : (tensor<8x8xf32>) -> () tf_device.return }) {_mesh = "CPU|x=1|0|0|/job:localhost/task:0/device:CPU:0"} : () -> () "tf_device.cluster"() ({ - %0 = "tf.DTensorRecv"() {key = "communication_key_sharding_specs:, mesh:GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_0", layout = #dtensor.layout, shape = #tf_type.shape<8x8>} : () -> tensor<8x8xf32> + %0 = "tf.DTensorRecv"() {key = "communication_key_GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3_0", mesh = #dtensor.mesh, shape = #tf_type.shape<8x8>, source_layout = #dtensor.layout} : () -> tensor<8x8xf32> %1 = "tf.DTensorLayout"(%0) {_global_shape = [#tf_type.shape<8x8>], global_shape = #tf_type.shape<8x8>, layout = #dtensor.layout} : (tensor<8x8xf32>) -> tensor<8x8xf32> tf_device.return }) {_mesh = "GPU|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:GPU:0,/job:localhost/task:0/device:GPU:1,/job:localhost/task:0/device:GPU:2,/job:localhost/task:0/device:GPU:3"} : () -> () diff --git a/tensorflow/dtensor/python/api.py b/tensorflow/dtensor/python/api.py index 2f303e9aa3a218..cbfab30e768a1b 100644 --- a/tensorflow/dtensor/python/api.py +++ b/tensorflow/dtensor/python/api.py @@ -185,8 +185,7 @@ def copy_to_mesh( A DTensor on the DTensor device with the given layout. 
""" del source_layout - with default_mesh(layout.mesh): - return gen_dtensor_ops.copy_to_mesh(tensor, layout.to_string()) + return relayout(tensor, layout) @tf_export("experimental.dtensor.pack", v1=[]) @@ -548,7 +547,6 @@ def _copy_to_mesh_gradient(op, grad): grad = gen_dtensor_ops.copy_to_mesh_grad( grad, forward_input=op.inputs[0], - reference_layout=op.get_attr("layout"), ) return grad @@ -558,6 +556,5 @@ def _copy_to_mesh_grad_gradient(op, grad): grad = gen_dtensor_ops.copy_to_mesh_grad( grad, forward_input=op.inputs[0], - reference_layout=op.get_attr("reference_layout"), ) return grad, None diff --git a/tensorflow/dtensor/python/dtensor_device.py b/tensorflow/dtensor/python/dtensor_device.py index 8a252e64f1e402..5c6dff23052b66 100644 --- a/tensorflow/dtensor/python/dtensor_device.py +++ b/tensorflow/dtensor/python/dtensor_device.py @@ -23,7 +23,6 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.dtensor.python import config -from tensorflow.dtensor.python import gen_dtensor_ops from tensorflow.dtensor.python import layout as layout_lib from tensorflow.python import _pywrap_dtensor_device from tensorflow.python.eager import context @@ -32,7 +31,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_util from tensorflow.python.util import _pywrap_utils @@ -120,12 +118,6 @@ def _register_mesh(self, mesh: layout_lib.Mesh): def meshes(self) -> Set[layout_lib.Mesh]: return self._meshes - def copy_to_mesh(self, tensor, new_layout) -> tensor_lib.Tensor: - """Copy `tensor` to `device` with the given layout.""" - self._register_mesh(new_layout.mesh) - with ops.device(self.name): - return gen_dtensor_ops.copy_to_mesh(tensor, layout=new_layout.to_string()) - def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any: """Packs tensors into a DTensor handle on this DTensor device. diff --git a/tensorflow/dtensor/python/tests/device_test.py b/tensorflow/dtensor/python/tests/device_test.py index 08bfb3a06b2444..3e8c073e4724f0 100644 --- a/tensorflow/dtensor/python/tests/device_test.py +++ b/tensorflow/dtensor/python/tests/device_test.py @@ -79,10 +79,12 @@ def testAsyncOption(self, is_async): # There isn't a great way to test whether something actually executed # synchronously; this test just exercises the option. 
device = dtensor_device.DTensorDevice([], is_async=is_async) - a = device.copy_to_mesh( - constant_op.constant([1.0]), Layout.replicated(self.mesh, rank=1) - ) - b = array_ops.identity(a) + with device._experimental_default_mesh(self.mesh): + with ops.device_v2(device.name): + a = api.copy_to_mesh( + constant_op.constant([1.0]), Layout.replicated(self.mesh, rank=1) + ) + b = array_ops.identity(a) self.assertEqual([1.], b.numpy()) def testBasicTypeBasedDispatch(self): diff --git a/tensorflow/dtensor/python/tests/multi_mesh_test.py b/tensorflow/dtensor/python/tests/multi_mesh_test.py index 735d833bb44e88..e938e1bc8c01b9 100644 --- a/tensorflow/dtensor/python/tests/multi_mesh_test.py +++ b/tensorflow/dtensor/python/tests/multi_mesh_test.py @@ -756,6 +756,45 @@ def second(x): self.assertDTensorEqual(-0.5 * 0.5 * 0.5 * (1 / numpy_a)**1.5, host_layout, a_grad_grad) + def testMultiMeshMultipleCopyToMesh(self): + self.skipForDeviceType( + ['CPU'], + 'Skipping test as only CPU mesh is available for multi-meshtest.', + ) + + sharded_layout_on_tpu = Layout([_MESH_DIM_X], self.second_mesh) + host_layout = Layout( + sharded_layout_on_tpu.sharding_specs, + sharded_layout_on_tpu.mesh.host_mesh(), + ) + + source_layout = host_layout + target_layout = sharded_layout_on_tpu + + numpy_a = constant_op.constant([1, 2, 3, 4], dtype=dtypes.int32) + numpy_b = constant_op.constant([2, 2, 3, 4], dtype=dtypes.int32) + + # TODO(b/193443769): switch to a single copy_to_mesh when this is supported. + replicated_layout = Layout.replicated( + source_layout.mesh, source_layout.rank + ) + a = api.copy_to_mesh(numpy_a, replicated_layout) + b = api.copy_to_mesh(numpy_b, replicated_layout) + a = api.relayout(a, source_layout) + b = api.relayout(b, source_layout) + + @polymorphic_function.function + def func(a, b): + a = api.copy_to_mesh(a, target_layout) + b = api.copy_to_mesh(b, target_layout) + return array_ops.identity(a), array_ops.identity(b) + + with ops.device_v2(api.device_name()): + dtensor_a, dtensor_b = func(a, b) + + self.assertDTensorEqual(numpy_a, target_layout, dtensor_a) + self.assertDTensorEqual(numpy_b, target_layout, dtensor_b) + def testDVariableDefaultMesh(self): other_layout = Layout.replicated(_OTHER_CPU_MESH, rank=0) first_layout = Layout.replicated(_ONE_D_CPU_MESH, rank=0) From eb11c5b9b07e2d0b16ea6ab2d84751de27b60d6f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 18:35:43 -0700 Subject: [PATCH 164/410] Add a Python wrapper function for gen_math_ops.floor_mod in math_ops.py. This wrapper is needed for the dunder method __mod__ and __rmod__ to call the WeakTensor patched op. PiperOrigin-RevId: 551057080 --- tensorflow/python/ops/math_ops.py | 26 ++++++++++++++++--- .../python/ops/weak_tensor_math_ops_test.py | 9 +++++++ tensorflow/python/ops/weak_tensor_ops.py | 8 ------ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 10bdf8dac1ec2a..27718ec479a4c7 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1776,8 +1776,28 @@ def multiply_no_nan(x, y, name=None): return gen_math_ops.mul_no_nan(x, y, name=name) -# TODO(aselle): This should be removed -mod = gen_math_ops.floor_mod +def mod(x, y, name=None): + r"""Returns element-wise remainder of division. + + This follows Python semantics in that the + result here is consistent with a flooring divide. E.g. + `floor(x / y) * y + floormod(x, y) = x`, regardless of the signs of x and y. 
+ + *NOTE*: `math.floormod` supports broadcasting. More about broadcasting + [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + + Args: + x: A `Tensor`. Must be one of the following types: `int8`, `int16`, `int32`, + `int64`, `uint8`, `uint16`, `uint32`, `uint64`, `bfloat16`, `half`, + `float32`, `float64`. + y: A `Tensor`. Must have the same type as `x`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `x`. + """ + with ops.name_scope(name, "mod", [x, y]) as name: + return gen_math_ops.floor_mod(x, y, name=name) @tf_export("math.floordiv", v1=["math.floordiv", "floordiv"]) @@ -1874,7 +1894,7 @@ def _mul_dispatch(x, y, name=None): _OverrideBinaryOperatorHelper(div, "div") _OverrideBinaryOperatorHelper(truediv, "truediv") _OverrideBinaryOperatorHelper(floordiv, "floordiv") -_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") +_OverrideBinaryOperatorHelper(mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") diff --git a/tensorflow/python/ops/weak_tensor_math_ops_test.py b/tensorflow/python/ops/weak_tensor_math_ops_test.py index 3cbcd133ea1132..7929c9ec1a7637 100644 --- a/tensorflow/python/ops/weak_tensor_math_ops_test.py +++ b/tensorflow/python/ops/weak_tensor_math_ops_test.py @@ -475,12 +475,21 @@ class RHSReturnsTrue: def __radd__(self, other): return True + def __rmod__(self, other): + return False + a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsTrue() self.assertEqual(a, True) a = _get_weak_tensor(5, dtype=dtypes.int32) + RHSReturnsTrue() self.assertEqual(a, True) + a = array_ops.ones([1], dtype=dtypes.float32) % RHSReturnsTrue() + self.assertEqual(a, False) + + a = _get_weak_tensor(5, dtype=dtypes.float32) % RHSReturnsTrue() + self.assertEqual(a, False) + class RHSRaisesError: def __radd__(self, other): diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index ec92340712c71a..75535ab505f8ce 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -479,20 +479,12 @@ def _update_weak_tensor_patched_ops_in_dispatch_dict(patched_op): # ============================================================================== # Update old op references. # ============================================================================== -math_ops.mod = gen_math_ops.floor_mod math_ops.realdiv = gen_math_ops.real_div math_ops.truncatediv = gen_math_ops.truncate_div math_ops.floor_div = gen_math_ops.floor_div math_ops.truncatemod = gen_math_ops.truncate_mod math_ops.floormod = gen_math_ops.floor_mod -# Update Tensor dunder methods. -# Rest of the dunder methods call the updated op because those ops have -# Python wrapper functions that call the patched op. (e.g. __add__ = -# _add_dispatch and _add_dispatch calls the updated math_ops.add). -tensor.Tensor.__mod__ = gen_math_ops.floor_mod -tensor.Tensor.__rmod__ = weak_tensor_binary_op_wrapper(tensor.Tensor.__rmod__) - # Set WeakTensor dunder methods. # Tensor unary ops do not need WeakTensor support. weak_tensor.WeakTensor.__invert__ = math_ops.invert_ From 9c260f899566f26a790b5b1051a7ff3d7e47f3cc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 18:45:44 -0700 Subject: [PATCH 165/410] Potential improvement opportunity to eliminate extra transpose op. 
This CL will add patterns to fold Transpose and FC into a BMM, like below: FC(lhs, Transpose(rhs)) -> BMM(lhs, rhs, false, false) The right thing to do in this pattern will be to apply the pattern only if keep_num_dims==True, because if the output rank is less than the input rank, it means `keep_num_dims` has reduced the output. But checking for rank will improve the coverage. This pattern will now work for all the transpose->fc patterns where the fc has not done an implicit reshape of the output. PiperOrigin-RevId: 551058717 --- .../tests/end2end/unroll_batch_matmul.pbtxt | 7 ++--- .../compiler/mlir/lite/tests/optimize.mlir | 28 +++++++++++++++++++ .../mlir/lite/transforms/optimize_patterns.td | 26 +++++++++++++++++ tensorflow/lite/python/analyzer_test.py | 18 +++++------- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt index 0657bcbf8126d5..1d6c0ef73bbab9 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -78,15 +78,12 @@ versions { } # CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { -# CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> -# CHECK-DAG: %[[VAL_3:.*]] = "tfl.no_value"() {value} : () -> none # CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : tensor # CHECK: %[[VAL_7:.*]]:2 = "tfl.split"(%[[VAL_6]], %[[VAL_0]]) {num_splits = 2 : i32} : (tensor, tensor<2x5x3xf32>) -> (tensor<1x5x3xf32>, tensor<1x5x3xf32>) # CHECK: %[[VAL_8:.*]] = "tfl.reshape"(%[[VAL_7]]#0, %cst) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> # CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_7]]#1, %cst) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> -# CHECK: %[[VAL_10:.*]] = "tfl.transpose"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> -# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_8]], %[[VAL_10]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> -# CHECK: %[[VAL_12:.*]] = "tfl.fully_connected"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.batch_matmul"(%[[VAL_8]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<5x3xf32>, tensor<3x7xf32>) -> tensor<5x7xf32> +# CHECK: %[[VAL_12:.*]] = "tfl.batch_matmul"(%[[VAL_9]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<5x3xf32>, tensor<3x7xf32>) -> tensor<5x7xf32> # CHECK: %[[VAL_13:.*]] = "tfl.pack"(%[[VAL_11]], %[[VAL_12]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> # CHECK: return %[[VAL_13]] : tensor<2x5x7xf32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index da0cc252529c61..8d8c8d6eb50f90 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -686,6 +686,34 @@ func.func @FuseTransposeIntoBMM_RHS2(%arg0: tensor, %arg1: tensor } +// CHECK-LABEL: 
@FuseTransposeIntoFC_RHS +func.func @FuseTransposeIntoFC_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<256x1440xf32>) -> tensor<1x4x1440x1440xf32> { + %cst = "tfl.no_value"() {value} : () -> none + %cst_1 = arith.constant dense<[1, 0]> : tensor<2xi32> + %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<256x1440xf32>, tensor<2xi32>) -> tensor<1440x256xf32> + %33 = "tfl.fully_connected"(%arg0, %32, %cst) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<1x4x1440x256xf32>, tensor<1440x256xf32>, none) -> tensor<1x4x1440x1440xf32> + return %33 : tensor<1x4x1440x1440xf32> + // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<256x1440xf32>) -> tensor<1x4x1440x1440xf32> + // CHECK: return %0 : tensor<1x4x1440x1440xf32> +} + +// CHECK-LABEL: @FuseTransposeIntoFC_RHS1 +func.func @FuseTransposeIntoFC_RHS1(%arg0: tensor<1x384x64x!quant.uniform>, %arg1: tensor<1x64x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> { + %cst = "tfl.no_value"() {value} : () -> none + %cst_6 = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst_7 = arith.constant dense<[64, 384]> : tensor<2xi32> + %cst_8 = arith.constant dense<[1, 384, 384]> : tensor<3xi32> + %46 = "tfl.reshape"(%arg1, %cst_7) : (tensor<1x64x384x!quant.uniform>, tensor<2xi32>) -> tensor<64x384x!quant.uniform> + %53 = "tfl.transpose"(%46, %cst_6) : (tensor<64x384x!quant.uniform>, tensor<2xi32>) -> tensor<384x64x!quant.uniform> + %33 = "tfl.fully_connected"(%arg0, %53, %cst) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x384x64x!quant.uniform>, tensor<384x64x!quant.uniform>, none) -> tensor<384x384x!quant.uniform> + %34 = "tfl.reshape"(%33, %cst_8) : (tensor<384x384x!quant.uniform>, tensor<3xi32>) -> tensor<1x384x384x!quant.uniform> + return %34 : tensor<1x384x384x!quant.uniform> + // CHECK: %cst = arith.constant dense<[64, 384]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg1, %cst) : (tensor<1x64x384x!quant.uniform>, tensor<2xi32>) -> tensor<64x384x!quant.uniform> + // CHECK: %1 = "tfl.batch_matmul"(%arg0, %0) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x384x64x!quant.uniform>, tensor<64x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> + // CHECK: return %1 : tensor<1x384x384x!quant.uniform> +} + // CHECK-LABEL: @FuseTransposeIntoBMM_LHS func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 34a0cd741f6998..69a9f19619ed77 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -441,6 +441,12 @@ def IsRankLessThanEqualTo : Constraint().getRank() <= " "$1.getType().cast().getRank()">>; +// Constraint check to test if the rank of an element in a ValueRange($0) +// is equal to the value($1) +class IsValueArrayElementRankEqualTo : Constraint().getRank() == " + "$1.getType().cast().getRank()">>; + def Flatten : NativeCodeCall< "$0.cast()" ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " @@ -1506,3 +1512,23 @@ def FuseTransposeIntoBatchMatMulLHS: Pat< [(AreLastTwoDimsTransposed $perm_value), (IsBoolAttrEqual<"false"> $adj_x)]>; 
+// Fuse redundant TFL_TransposeOp into TFL_FullyConnecedOp to form TFL_BatchMatMulOp +def FuseTransposeIntoFullyConnectedRHS: Pat< + (TFL_FullyConnectedOp:$output $lhs, + (TFL_TransposeOp $input, (Arith_ConstantOp:$perm_value $p0)), + $bias, + TFL_AF_None, + TFL_FCWO_Default, + $keep_num_dims, + $asymmetric_quantize_inputs), + (TFL_BatchMatMulOp $lhs, $input, + ConstBoolAttrFalse, ConstBoolAttrFalse, $asymmetric_quantize_inputs), + [(IsNoneType $bias), + //if the output rank is less-than the input rank, it means keep_num_dims has + //reduced the output. The right thing to do here will be to apply the pattern + //only if keep_num_dims==True. But checking for rank will improve the + //coverage. This pattern will now work for all the transpose->fc patterns + //where the fc has not done imlicit reshape of the output + (IsValueArrayElementRankEqualTo<0> $output, $lhs), + (AreLastTwoDimsTransposed $perm_value)]>; + diff --git a/tensorflow/lite/python/analyzer_test.py b/tensorflow/lite/python/analyzer_test.py index 22a6a939ba900c..a262823a89affd 100644 --- a/tensorflow/lite/python/analyzer_test.py +++ b/tensorflow/lite/python/analyzer_test.py @@ -221,23 +221,19 @@ def func(lhs, rhs): with test.mock.patch.object(sys, 'stdout', mock_stdout): analyzer.ModelAnalyzer.analyze(model_content=fb_model) txt = mock_stdout.getvalue() - self.assertIn('Op#0 RESHAPE(T#1, T#5[512, 512]) -> [T#6]', txt) - self.assertIn('Op#1 RESHAPE(T#0, T#3[100, 512]) -> [T#7]', txt) - self.assertIn('Op#2 TRANSPOSE(T#6, T#4[1, 0]) -> [T#8]', txt) - self.assertIn('Op#3 FULLY_CONNECTED(T#7, T#8, T#-1) -> [T#9]', txt) - self.assertIn('Op#4 RESHAPE(T#9, T#2[1, 100, 8, 64]) -> [T#10]', txt) + self.assertIn('Op#0 RESHAPE(T#1, T#4[512, 512]) -> [T#5]', txt) + self.assertIn('Op#1 RESHAPE(T#0, T#3[100, 512]) -> [T#6]', txt) + self.assertIn('Op#2 BATCH_MATMUL(T#6, T#5) -> [T#7]', txt) + self.assertIn('Op#3 RESHAPE(T#7, T#2[1, 100, 8, 64]) -> [T#8]', txt) self.assertIn( 'T#2(einsum/Einsum) shape:[4], type:INT32 RO 16 bytes, ' 'buffer: 3, data:[1, 100, 8, 64]', txt) self.assertIn( - 'T#3(einsum/Einsum2) shape:[2], type:INT32 RO 8 bytes, ' + 'T#3(einsum/Einsum1) shape:[2], type:INT32 RO 8 bytes, ' 'buffer: 4, data:[100, 512]', txt) self.assertIn( - 'T#4(einsum/Einsum3) shape:[2], type:INT32 RO 8 bytes, ' - 'buffer: 5, data:[1, 0]', txt) - self.assertIn( - 'T#5(einsum/Einsum4) shape:[2], type:INT32 RO 8 bytes, ' - 'buffer: 6, data:[512, 512]', txt) + 'T#4(einsum/Einsum2) shape:[2], type:INT32 RO 8 bytes, ' + 'buffer: 5, data:[512, 512]', txt) if __name__ == '__main__': test.main() From 865c951e5f08dd72bc15c66ebd64bf39f7e31588 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 19:18:27 -0700 Subject: [PATCH 166/410] Update ops-related pbtxt files. PiperOrigin-RevId: 551063559 --- tensorflow/core/ops/ops.pbtxt | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 7d2b4d103ed09f..73d793d535eaa2 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -10681,7 +10681,7 @@ op { type_attr: "T" } attr { - name: "layout" + name: "mesh" type: "string" } attr { @@ -10703,13 +10703,6 @@ op { name: "output" type_attr: "T" } - attr { - name: "reference_layout" - type: "string" - default_value { - s: "" - } - } attr { name: "T" type: "type" From 53c7d31a6a3e6185f0b4eb11d3560c7b3dd27e93 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 25 Jul 2023 19:47:11 -0700 Subject: [PATCH 167/410] Integrate LLVM at llvm/llvm-project@1c154bd75515 Updates LLVM usage to match [1c154bd75515](https://github.com/llvm/llvm-project/commit/1c154bd75515) PiperOrigin-RevId: 551067352 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index afed500e54f70d..c877a9c5b35a34 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "049d6a3f428efeb1a22f62e55b808f60b0bf27cc" - LLVM_SHA256 = "559ecaa58d53ac29c6da5df8b8b358ba8ec99f3cd8e7d08babe165f8506c2c19" + LLVM_COMMIT = "1c154bd755153b5c6ada4bbed58facf23f6abffc" + LLVM_SHA256 = "212983fc56c42105e34690097801981974acec667e9753a6465d7a5c07075e0e" tf_http_archive( name = name, From 09f6e9e9216fbb1ddaa0d4b39f09a087fd1bb5db Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 25 Jul 2023 20:31:00 -0700 Subject: [PATCH 168/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/dcffa4094e13d56a30b2cdcdb709ce5d71b38953. PiperOrigin-RevId: 551075708 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index f94f03e669937a..fb499cb5f9dae7 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "e143099db75fe9b89b3959965840dd3c4237eab3" - TFRT_SHA256 = "5591f65576dabeadbba8ebe7d2d95d47576389de9b0f433cbc4e33fb9bd7eb08" + TFRT_COMMIT = "dcffa4094e13d56a30b2cdcdb709ce5d71b38953" + TFRT_SHA256 = "ce4a95e4fe258353e751af57e6970a186aab77f03f12944ce389667b33d6c5b1" tf_http_archive( name = "tf_runtime", From 35642696f4fc620fc88250bed83bb1b73c95a695 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 25 Jul 2023 21:15:42 -0700 Subject: [PATCH 169/410] [xla] Add latency-hiding-scheduler-preparation pass. This new pass currently amends control dependency to ensure that a Send-Recv sequence generated by the collective-permute-decomposer is scheduled before an instruction that uses the Recv result and also calls nested computations with Send-Recv operation. Add tests. 
PiperOrigin-RevId: 551083114 --- tensorflow/compiler/xla/service/BUILD | 33 ++++ .../latency_hiding_scheduler_preparation.cc | 182 ++++++++++++++++++ .../latency_hiding_scheduler_preparation.h | 93 +++++++++ ...tency_hiding_scheduler_preparation_test.cc | 161 ++++++++++++++++ 4 files changed, 469 insertions(+) create mode 100644 tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.cc create mode 100644 tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h create mode 100644 tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 573f6c4e8a9873..5f5284becfbcf3 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1189,6 +1189,39 @@ xla_cc_test( ], ) +cc_library( + name = "latency_hiding_scheduler_preparation", + srcs = ["latency_hiding_scheduler_preparation.cc"], + hdrs = ["latency_hiding_scheduler_preparation.h"], + deps = [ + ":collective_ops_utils", + ":hlo_pass", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo_reachability", + "//tensorflow/compiler/xla/service/gpu:backend_configs_cc", + "//tensorflow/compiler/xla/service/graphcycles", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "latency_hiding_scheduler_preparation_test", + srcs = ["latency_hiding_scheduler_preparation_test.cc"], + deps = [ + ":hlo_parser", + ":latency_hiding_scheduler_preparation", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/hlo/utils:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:test_main", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "profile_guided_latency_estimator", srcs = ["profile_guided_latency_estimator.cc"], diff --git a/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.cc b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.cc new file mode 100644 index 00000000000000..78e4adfee3e227 --- /dev/null +++ b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.cc @@ -0,0 +1,182 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h" + +#include +#include + +#include "absl/log/log.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_reachability.h" + +namespace xla { + +namespace { + +// Returns a boolean to indicate whether the operation is a non-host P2P +// operation. We exclude non-host P2P operations for two reasons: (1) this pass +// currently can only amend control dependence for non-host P2P operations. (2) +// we need to exclude host P2P operations when looking for a nested chain +// of non-host P2P operations. +bool IsP2POp(const HloInstruction* op) { + auto p2p = DynCastOrNull(op); + return p2p != nullptr && !p2p->is_host_transfer(); +} + +// Returns the predecessor of op with instruction type T1 or nullptr if such a +// control predecessor doesn't exist. The routine gives an error if there are +// more than one such control predecessor. +template +const T1* GetChainedOp(const HloInstruction* op) { + const T1* chained_op = nullptr; + for (const HloInstruction* predecessor : op->control_predecessors()) { + auto tmp = DynCastOrNull(predecessor); + if (!tmp || !IsP2POp(tmp)) { + continue; + } + CHECK_EQ(chained_op, nullptr); + chained_op = tmp; + } + return chained_op; +} + +// Given a send_done, returns the recv_done if it is in a chain of +// Recv => Send => RecvDone => SendDone +// Returns nullptr if such a chain doesn't exist. +const HloRecvDoneInstruction* GetChainedRecvDone( + const HloSendDoneInstruction* send_done) { + const HloRecvDoneInstruction* recv_done = + GetChainedOp(send_done); + if (!recv_done) { + return nullptr; + } + + auto send = DynCast(send_done->operand(0)); + CHECK_NE(send, nullptr); + CHECK_EQ(send->is_host_transfer(), false); + + const HloRecvInstruction* recv = GetChainedOp(send); + if (!recv) { + return nullptr; + } + if (recv_done->operand(0) != recv) { + return nullptr; + } + + return recv_done; +} + +// Inspects the instructions in the computation to find out whether the +// computation directly or indirectly invoke P2P operations, and records the +// finding in p2p_in_computation_cache. Also returns the boolean result. +bool FindP2PInComputation( + absl::flat_hash_map& p2p_in_computation_cache, + const HloComputation* computation) { + auto it = p2p_in_computation_cache.find(computation); + if (it != p2p_in_computation_cache.end()) { + return it->second; + } + bool result = false; + for (HloInstruction* instr : computation->instructions()) { + if (IsP2POp(instr)) { + result = true; + break; + } + for (const HloComputation* called_computation : + instr->called_computations()) { + if (FindP2PInComputation(p2p_in_computation_cache, called_computation)) { + result = true; + break; + } + } + } + p2p_in_computation_cache[computation] = result; + return result; +} + +// Returns a boolean to indicate whether there are any operation in the range +// [start, end] that contains non-host P2P transfer that are reachable from the +// given instruction. 
+bool OperationChainHasP2P( + absl::flat_hash_map& p2p_in_computation_cache, + const std::vector::const_iterator& start, + const std::vector::const_iterator& end, + const HloReachabilityMap* reachability, const HloInstruction* instr) { + for (auto it_op = start; it_op != end; ++it_op) { + const HloInstruction* op = *it_op; + if (!reachability->IsReachable(instr, op)) continue; + + if (IsP2POp(op)) { + return true; + } + + for (const HloComputation* called_comp : op->called_computations()) { + if (FindP2PInComputation(p2p_in_computation_cache, called_comp)) { + return true; + } + } + } + return false; +} + +} // namespace + +StatusOr LatencyHidingSchedulerPreparation::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + absl::flat_hash_map p2p_in_computation_cache; + for (HloComputation* computation : module->computations(execution_threads)) { + std::unique_ptr reachability; + std::vector all_instructions = + computation->MakeInstructionPostOrder(); + for (auto it = all_instructions.begin(); it != all_instructions.end(); + ++it) { + HloInstruction* hlo = *it; + if (hlo->opcode() != HloOpcode::kSendDone) { + continue; + } + auto send_done = Cast(hlo); + if (send_done->is_host_transfer()) { + continue; + } + const HloRecvDoneInstruction* recv_done = GetChainedRecvDone(send_done); + if (!recv_done) { + continue; + } + if (reachability == nullptr) { + reachability = HloReachabilityMap::Build(computation); + } + for (HloInstruction* recv_data : recv_done->users()) { + if (OperationChainHasP2P(p2p_in_computation_cache, it, + all_instructions.end(), reachability.get(), + recv_data)) { + // We need to schedule send_done before recv_data to avoid deadlock. + TF_RETURN_IF_ERROR(send_done->AddControlDependencyTo(recv_data)); + VLOG(10) << "Add control predecessor to " << recv_data->ToString(); + changed = true; + } + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h new file mode 100644 index 00000000000000..b74e2271c3a9c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LATENCY_HIDING_SCHEDULER_PREPARATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LATENCY_HIDING_SCHEDULER_PREPARATION_H_ + +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// LatencyHidingSchedulerPreparation is a pass to linearize certain operations +// to prepare for the latency hiding scheduler (LHS). 
In particular, this pass +// currently does the following: +// +// Adds control prececessors/successors to ensure that a P2P Send-Recv sequence +// on a non-host device will be scheduled before other operations that use the +// Recv result and may also invoke P2P operations indirectly. Here is an example +// to illustrate the problem we address: +// +// Assume a computation with the following HLO instructions, where while-body +// invokes collective-permute operations: +// collective-permute-start = (u32[2], u32[2]) +// collective-permute-start(data) ... +// collective-permute-done = u32[2] +// collective-permute-done(collective-permute-start) +// while-init = (u32[], u32[2]) tuple(c0, collective-permute-done) +// while-result = (u32[], u32[2]) while(while-init), +// body=while-body, condition=while-cond +// +// Without collective-permute-decomposer transformation, LHS will Schedule +// while-result after collective-permute-start without any problem. +// +// Now assume we transform the collective-permute operations in the computation +// as well as inside the while-body into a sequence of P2P Send-Recv sequence, +// the computation will become something like this: +// after-all = token[] after-all() +// recv = (u32[2], token[]) recv(after-all) ... +// send = (u32[2], token[]) send(data, after-all), +// control-predecessors={recv} ... +// recv-done = (u32[2], token[]) recv-done(recv), +// control-predecessors={send} ... +// send-done = token[] send-done(send), +// control-predecessors={recv-done} ... +// recv-data = u32[2] get-tuple-element(recv-done), index=0 +// while-init = (u32[], u32[2]) tuple(c0, recv-data) +// while-result = (u32[], u32[2]) while(while_init), +// body=while_body, condition=while_cond +// +// When scheduling this computation in a bottom up fashion, the LHS will reach a +// point where both while-result and send-done are in the ready queue. If LHS +// picks send-done over while-result, the scheduler is stuck because +// while-result can't be scheduled when the Send-Recv chain is holding the +// resources for P2P operations and recv-done cannot be scheduled as well +// because while-result depends on while-init which depends on recv-done. To +// avoid this deadlock, we make send-done a control predecessor of recv-data +// in this pass. +// +// Note that instead of making send-done a control predecessor of recv-data, we +// may make send-done a control predecessor of the instruction that contains +// the nested P2P operations, which is while-result in this example. This allows +// recv-data and while-init to be scheduled before send-done. However, doing so +// would complicate the implementation. We leave this to future improvement if +// we will find out it can actually help performance in real practice. +class LatencyHidingSchedulerPreparation : public HloModulePass { + public: + absl::string_view name() const override { + return "latency-hiding-scheduler-preparation"; + } + + using HloPassInterface::Run; + // Runs LatencyHidingSchedulerPreparation pass on computations in 'module'. + // Returns whether the 'module' was changed. 
+ StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LATENCY_HIDING_SCHEDULER_PREPARATION_H_ diff --git a/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation_test.cc b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation_test.cc new file mode 100644 index 00000000000000..ac2b30cb1ac8a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/latency_hiding_scheduler_preparation.h" + +#include +#include + +#include +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +using LatencyHidingSchedulerPreparationTest = HloTestBase; + +constexpr char kEmpty[] = ""; +constexpr char kHostTransfer[] = ", is_host_transfer=true"; +constexpr char kChainRecvDoneToSendDone[] = + ", control-predecessors={recv-done.1}"; + +// Returns an HLO module string for testing, which is generated from a +// templated string with placeholders for specifying the following values: +// Whether the Send/Recv operations in the while-body are host transfer +// Whether the Send/Recv operations in the main computation are host transfer +// Whether the SendDone/RecvDone operations in the main computation are chained +// +std::string GetHloModuleString(bool whileP2PIsHost = false, + bool mainP2PIsHost = false, + bool chainRecvDoneToSendDone = true) { + // A template string for the input HLO module. 
+ constexpr char kModuleTemplate[] = R"( + HloModule test + while-cond { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(param), index=0 + ub = u32[] constant(25) + ROOT cond-result = pred[] compare(count, ub), direction=LT + } + + while-body { + param = (u32[], f32[1, 1024, 1024]) parameter(0) + count = get-tuple-element(param), index=0 + send-data = get-tuple-element(param), index=1 + + after-all = token[] after-all() + recv = (f32[1, 1024, 1024], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } %s + send = (f32[1, 1024, 1024], u32[], token[]) send(send-data, after-all), + channel_id=1, control-predecessors={recv}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } %s + recv-done = (f32[1, 1024, 1024], token[]) recv-done(recv), channel_id=1, control-predecessors={send} %s + recv-data = f32[1, 1024, 1024] get-tuple-element(recv-done), index=0 + send-done = token[] send-done(send), control-predecessors={recv-done}, channel_id=1 %s %s + c1 = u32[] constant(1) + new-count = u32[] add(count, c1) + ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new-count, recv-data) + } + + ENTRY main { + c0 = u32[] constant(0) + f0 = f32[] constant(0.0) + init = f32[1, 1024, 1024] broadcast(f0), dimensions={} + + after-all.1 = token[] after-all() + recv.1 = (f32[1, 1024, 1024], u32[], token[]) recv(after-all.1), channel_id=2, + frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } %s + send.1 = (f32[1, 1024, 1024], u32[], token[]) send(init, after-all.1), + channel_id=2, control-predecessors={recv.1}, frontend_attributes={ + _xla_send_recv_source_target_pairs="{{0, 1}}" + } %s + recv-done.1 = (f32[1, 1024, 1024], token[]) recv-done(recv.1), channel_id=2, control-predecessors={send.1} %s + send-done.1 = token[] send-done(send.1), channel_id=2 %s + recv-data.1 = f32[1, 1024, 1024] get-tuple-element(recv-done.1), index=0 + + while-init = (u32[], f32[1, 1024, 1024]) tuple(c0, recv-data.1) + while-result = (u32[], f32[1, 1024, 1024]) while(while-init), + body=while-body, condition=while-cond + + while-result-data = f32[1, 1024, 1024] get-tuple-element(while-result), index=1 + ROOT entry-result = f32[1, 1024, 1024] add(while-result-data, recv-data.1) + } + )"; + const char* while_p2p = whileP2PIsHost ? kHostTransfer : kEmpty; + const char* main_p2p = mainP2PIsHost ? kHostTransfer : kEmpty; + const char* chain = + chainRecvDoneToSendDone ? 
kChainRecvDoneToSendDone : kEmpty; + return absl::StrFormat(kModuleTemplate, while_p2p, while_p2p, while_p2p, + while_p2p, main_p2p, main_p2p, main_p2p, main_p2p, + chain); +} + +TEST_F(LatencyHidingSchedulerPreparationTest, WhileP2PIsHostNotTransformed) { + std::string kModuleStr = GetHloModuleString(/*whileP2PIsHost=*/true); + VLOG(0) << kModuleStr; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + LatencyHidingSchedulerPreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(LatencyHidingSchedulerPreparationTest, MainP2PIsHostNotTransformed) { + std::string kModuleStr = GetHloModuleString(/*whileP2PIsHost=*/false, + /*mainP2PIsHost=*/true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + LatencyHidingSchedulerPreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(LatencyHidingSchedulerPreparationTest, MainP2PNotChainedNotTransformed) { + std::string kModuleStr = + GetHloModuleString(/*whileP2PIsHost=*/false, + /*mainP2PIsHost=*/false, + /*chainRecvDoneToSendDone=*/false); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + LatencyHidingSchedulerPreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(LatencyHidingSchedulerPreparationTest, ChainedWithNestedP2PTransformed) { + std::string kModuleStr = GetHloModuleString(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + LatencyHidingSchedulerPreparation preparation; + TF_ASSERT_OK_AND_ASSIGN(bool changed, preparation.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* send_done = FindInstruction(module.get(), "send-done.1"); + HloInstruction* recv_data = FindInstruction(module.get(), "recv-data.1"); + EXPECT_EQ(recv_data->control_predecessors()[0], send_done); +} + +} // namespace +} // namespace xla From ed1cc75be5e404fb1726e3802c13a53edf1ed177 Mon Sep 17 00:00:00 2001 From: Doyeon Kim Date: Tue, 25 Jul 2023 23:10:57 -0700 Subject: [PATCH 170/410] Convert stablehlo.constant to arith.constant PiperOrigin-RevId: 551101755 --- .../mlir/quantization/stablehlo/BUILD | 30 ++++++++++++++++++- .../quantization/stablehlo/passes/passes.td | 7 ++++- .../stablehlo/passes/prepare_srq_quantize.cc | 4 +++ .../stablehlo/passes/prepare_srq_quantize.td | 28 +++++++++++++++++ .../stablehlo/tests/prepare_srq_quantize.mlir | 2 +- 5 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.td diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index c3b045f19a86b0..aa796e88734e4b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -2,7 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@llvm-project//mlir:tblgen.bzl", 
"gentbl_cc_library", "td_library") # TODO(b/264218457): Create stablehlo-quantization-opt and register passes to actually test. @@ -31,6 +31,7 @@ cc_library( name = "passes", srcs = [ "passes/prepare_srq_quantize.cc", + "passes/prepare_srq_quantize.inc", "passes/quantize_weight.cc", ], hdrs = [ @@ -61,6 +62,33 @@ cc_library( alwayslink = True, ) +td_library( + name = "quant_td_files", + srcs = [ + "passes/prepare_srq_quantize.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:FuncTdFiles", + "@stablehlo//:stablehlo_ops_td_files", + ], +) + +gentbl_cc_library( + name = "prepare_srq_quantize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "passes/prepare_srq_quantize.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/prepare_srq_quantize.td", + deps = [":quant_td_files"], +) + gentbl_cc_library( name = "bridge_passes_inc_gen", compatible_with = get_compatible_with_portable(), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 896aa1ba833395..aafb641963b219 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -24,6 +24,11 @@ def QuantizeWeightPass : Pass<"stablehlo-quantize-weight", "mlir::func::FuncOp"> def PrepareSrqQuantizePass : Pass<"stablehlo-prepare-srq-quantize", "mlir::func::FuncOp"> { let summary = "Prepare StableHLO dialect for static range quantization."; let constructor = "CreatePrepareSrqQuantizePass()"; - let dependentDialects = ["stablehlo::StablehloDialect", "quant::QuantizationDialect", "quantfork::QuantizationForkDialect"]; + let dependentDialects = [ + "stablehlo::StablehloDialect", + "quant::QuantizationDialect", + "quantfork::QuantizationForkDialect", + "arith::ArithDialect" + ]; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc index 12ccddcce58ba7..101c13c6a48ba3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.cc @@ -74,11 +74,15 @@ using ReplaceStatsWithQDQs = quant::ConvertStatsToQDQs; +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.inc" + void PrepareSrqQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); MLIRContext* ctx = func.getContext(); RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + // TODO: b/288046643 - Implement different activation bit width per op/op // instance. int bit_width; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.td new file mode 100644 index 00000000000000..d36a14f2ba9ba1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_srq_quantize.td @@ -0,0 +1,28 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "stablehlo/dialect/StablehloOps.td" + +// Converts stablehlo.constant to arith.constant for statically shaped +// constants. Needed for QuantizationDriver to recognize constants. +def ConvertStableHloConstToArithConst : Pat< + (StableHLO_ConstantOp:$res ElementsAttr:$value), + (Arith_ConstantOp $value), + [(AnyStaticShapeTensor $res)], (addBenefit 10)>; + \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir index 4c7d909094fe27..90c52f6b34750f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/prepare_srq_quantize.mlir @@ -8,7 +8,7 @@ func.func @main(%arg0: tensor) -> tensor { func.return %2 : tensor } -// CHECK: %[[cst:.*]] = stablehlo.constant +// CHECK: %[[cst:.*]] = arith.constant // CHECK: %[[q1:.*]] = "quantfork.qcast"(%arg0) // CHECK-SAME: quant.uniform // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]]) From 35b4081219728baa5f0b36dacebd03392474ac25 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 26 Jul 2023 01:08:35 -0700 Subject: [PATCH 171/410] Remove one use of inlining in XlaCallModule shape refinement. To improve debuggability, we want the shape refinement to make as few changes as possible to the module. In this change we remove one use of inlining. 
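The shape-refinement change below boils down to one idiom: give each main argument a single stablehlo.convert user before touching its type, so every other use keeps observing the original type and the module stays valid without running the inliner. A condensed sketch of that idiom, assuming MLIR's C++ API and the StableHLO dialect; the helper name RefineArgumentTypes and its standalone signature are illustrative, not part of XlaCallModuleLoader.

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/TypeRange.h"
    #include "stablehlo/dialect/StablehloOps.h"

    // Replaces each argument type of `main` with new_types[i]. A
    // stablehlo.convert back to the old type becomes the argument's sole
    // direct user, so func.return operands and call sites keep seeing the
    // original type.
    void RefineArgumentTypes(mlir::func::FuncOp main, mlir::TypeRange new_types) {
      mlir::Block &body = main.getBody().front();
      mlir::OpBuilder builder(main.getContext());
      builder.setInsertionPointToStart(&body);
      for (unsigned i = 0; i < body.getNumArguments(); ++i) {
        mlir::BlockArgument arg = body.getArgument(i);
        // The convert's result type is the *old* argument type, captured
        // before setType below changes it.
        auto convert = builder.create<mlir::stablehlo::ConvertOp>(
            arg.getLoc(), arg.getType(), arg);
        arg.replaceAllUsesExcept(convert, convert);
        arg.setType(new_types[i]);
      }
      main.setType(builder.getFunctionType(body.getArgumentTypes(),
                                           main.getResultTypes()));
    }
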
PiperOrigin-RevId: 551123379 --- .../compiler/tests/xla_call_module_test.py | 4 +- .../tf2xla/kernels/xla_call_module_loader.cc | 39 ++++++------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 8b75dc450188b1..8f7bed7fe54d8a 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -249,7 +249,7 @@ def test_wrong_actual_args_errors(self): # x: f32[a, 2], return x module, version = serialize(""" module @jit_f.0 attributes {jax.uses_shape_polymorphism = true} { - func.func public @main(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { return %arg0 : tensor } } @@ -279,7 +279,7 @@ def f(x, y): with self.assertRaisesRegex( errors.InvalidArgumentError, 'Element type mismatch for argument 1 passed to XlaCallModule: ' - r'expecting tensor<\*xi32>, got tensor<2x3xf32>', + r'expecting tensor<\?x\?xi32>, got tensor<2x3xf32>', ): self._assertOpOutputMatchesExpected(f, (x, y_bad_etype), (x,)) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 1557c7fd2dbc8d..79257958431c7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -375,40 +375,23 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( } } - // Refine 'main' argument types to use static input types instead. - // This will only change the argument types and will not propagate the - // additional type information further. For that, we'll need to run - // shape refinement as explained below. - // Before refining the argument types it is useful to run the inliner to - // remove calls that may be called with the input arguments. - { - mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - - mlir::PassManager pm_inline(module_->getContext()); - applyTensorflowAndCLOptions(pm_inline); - pm_inline.addPass(mlir::createInlinerPass()); - - if (mlir::failed(pm_inline.run(*module_))) { - return absl::InvalidArgumentError(absl::StrCat( - "Module inlining failed: ", diag_handler.ConsumeStatus().ToString())); - } + // Refine 'main' argument types to use static input types instead. The main + // arguments may occur as return values, or as inputs to called functions, + // and changing their types may invalidate the module. To prevent this + // we insert dummy conversion ops as the sole uses of the main arguments. + mlir::OpBuilder op_builder(module_->getBodyRegion()); + op_builder.setInsertionPointToStart(&main_body); + for (auto i = 0; i < main_body.getNumArguments(); ++i) { + mlir::BlockArgument arg = main_body.getArgument(i); + auto convert_op = op_builder.create( + arg.getLoc(), arg.getType(), arg); + arg.replaceAllUsesExcept(convert_op, convert_op); } auto static_array_output_types = llvm::to_vector(main_.getResultTypes()); for (auto i = 0; i < main_body.getNumArguments(); ++i) { auto arg = main_body.getArgument(i); arg.setType(static_array_input_types[i]); - // If the argument is used by `func.return`, then we also need to - // update the function result types. It's not great that we need this hack, - // but in the future when we have stablehlo.func, stablehlo.return, etc, - // this will not be needed. 
- // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is - // fixed, clean this up. - for (mlir::OpOperand &use : arg.getUses()) { - if (auto ret = llvm::dyn_cast(use.getOwner())) { - static_array_output_types[use.getOperandNumber()] = arg.getType(); - } - } } main_.setType(builder.getFunctionType(static_array_input_types, static_array_output_types)); From ab1f5ff898302ec149408c4d030241991544faf6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 26 Jul 2023 01:50:35 -0700 Subject: [PATCH 172/410] [xla:gpu] Disable multi-threading in MLIR context when lowering to LMHLO PiperOrigin-RevId: 551131546 --- .../compiler/xla/service/gpu/compile_module_to_llvm_ir.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc index 07347af3025796..f24fb7fac2fe68 100644 --- a/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/tensorflow/compiler/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -413,7 +413,12 @@ Status CompileModuleToLlvmIrImpl( uint64_t start_usecs = tsl::Env::Default()->NowMicros(); mlir::DialectRegistry registry; IrEmitterUnnested::GetDependentDialects(registry); - auto mlir_context = std::make_unique(registry); + + // Disable MLIR multi-threading to prevent creating too many threads when + // compiling XLA executables concurrently (e.g. during auto-tuning). + auto mlir_context = std::make_unique( + registry, mlir::MLIRContext::Threading::DISABLED); + mlir_context->getDiagEngine().registerHandler(DiagnosticHandler); mlir::OwningOpRef mlir_module = mlir::ModuleOp::create(mlir::Builder(mlir_context.get()).getUnknownLoc()); From f52793a65c5b39e38dbe63aa8a29caac0a5b0813 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 26 Jul 2023 01:59:48 -0700 Subject: [PATCH 173/410] Use local_defines instead of defines or copt. This is the preferred way, because it adds the define just for the compilation of that target. This change requires adding additional local_defines lines to targets that before relied on the defines line of a target they depend on. 
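For context, the macros being moved around here are consumed by the C++ preprocessor, which is why every target that tests GOOGLE_CUDA or TENSORFLOW_USE_ROCM in its own sources now needs its own local_defines entry: unlike defines, local_defines is not exported to dependents. A minimal sketch of that consumption, with an illustrative function name rather than real XLA code:

    #include <cstdio>

    // The branch is chosen per compilation unit: a target that sets
    // GOOGLE_CUDA=1 via local_defines gets the first branch for its own
    // sources only; nothing is propagated to targets that depend on it.
    int DefaultThreadsPerBlock() {
    #if defined(GOOGLE_CUDA) && GOOGLE_CUDA
      return 512;
    #else
      return 1024;
    #endif
    }

    int main() {
      std::printf("threads per block: %d\n", DefaultThreadsPerBlock());
      return 0;
    }
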
PiperOrigin-RevId: 551133550 --- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/gpu/BUILD | 29 +++++++++++++------ .../service/gpu/gemm_algorithm_picker_test.cc | 1 + .../xla/service/gpu/nvptx_compiler.cc | 1 + .../compiler/xla/service/gpu/runtime/BUILD | 12 +++++++- .../compiler/xla/stream_executor/gpu/BUILD | 2 ++ 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 5f5284becfbcf3..7d70ef1b756ea5 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -6773,6 +6773,7 @@ build_test( xla_cc_binary( name = "xla_compile", srcs = ["xla_compile_main.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), visibility = ["//visibility:public"], deps = [ ":compiler", diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 425279a2e6da6f..913b27a5bc868c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -300,8 +300,8 @@ cc_library( name = "ir_emitter_unnested", srcs = ["ir_emitter_unnested.cc"], hdrs = ["ir_emitter_unnested.h"], - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_hipblaslt([ - "-DTF_HIPBLASLT=1", + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]) + if_rocm_hipblaslt([ + "TF_HIPBLASLT=1", ]), deps = [ ":backend_configs_cc", @@ -880,6 +880,9 @@ cc_library( "sequential_thunk.h", "while_thunk.h", ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":backend_configs_cc", ":buffer_allocations", @@ -987,7 +990,7 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], compatible_with = get_compatible_with_portable(), - defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":target_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1120,7 +1123,7 @@ cc_library( name = "gemm_rewriter", srcs = ["gemm_rewriter.cc"], hdrs = ["gemm_rewriter.h"], - defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":cublas_cudnn", @@ -1247,6 +1250,9 @@ cc_library( hdrs = if_cuda_is_configured(["cublas_lt_matmul_thunk.h"]) + if_rocm_is_configured([ "cublas_lt_matmul_thunk.h", ]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = if_gpu_is_configured([ ":matmul_utils", ":thunk", @@ -1268,6 +1274,7 @@ cc_library( name = "gemm_algorithm_picker", srcs = if_cuda_is_configured(["gemm_algorithm_picker.cc"]), hdrs = if_cuda_is_configured(["gemm_algorithm_picker.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = if_cuda_is_configured([ ":backend_configs_cc", ":buffer_comparator", @@ -1360,6 +1367,7 @@ xla_cc_test( "requires-gpu-sm70", ], deps = [ + ":backend_configs_cc", ":gemm_algorithm_picker", ":gemm_rewriter", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -1380,7 +1388,7 @@ cc_library( srcs = ["matmul_utils.cc"], hdrs = ["matmul_utils.h"], compatible_with = get_compatible_with_portable(), - defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ 
"TENSORFLOW_USE_ROCM=1", ]), deps = [ @@ -1512,8 +1520,8 @@ cc_library( name = "conv_algorithm_picker", srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]), hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]), - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "-DTENSORFLOW_USE_ROCM=1", + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", ]), deps = if_gpu_is_configured([ ":backend_configs_cc", @@ -2291,6 +2299,9 @@ cc_library( hdrs = if_gpu_is_configured([ "gpu_compiler.h", ]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = if_gpu_is_configured([ ":alias_passthrough_params", ":all_reduce_blueconnect", @@ -2736,7 +2747,7 @@ cc_library( "infeed_manager.h", "outfeed_manager.h", ], - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":xfeed_queue", "//tensorflow/compiler/xla:literal", @@ -2946,7 +2957,7 @@ cc_library( "hlo_fusion_analysis.h", "kernel_mapping_scheme.h", ], - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":backend_configs_cc", ":gpu_device_info", diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker_test.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker_test.cc index fbe3c3661471a3..49721cceaa7de0 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index fe3f9e1de5e997..319218b18660e8 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -69,6 +69,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/path.h" #include "tensorflow/tsl/platform/status.h" diff --git a/tensorflow/compiler/xla/service/gpu/runtime/BUILD b/tensorflow/compiler/xla/service/gpu/runtime/BUILD index 7be2899dcd15c8..3707630d43c57c 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/BUILD +++ b/tensorflow/compiler/xla/service/gpu/runtime/BUILD @@ -27,6 +27,7 @@ cc_library( name = "cholesky", srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), deps = [ ":support", "//tensorflow/compiler/xla:xla_proto_cc", @@ -65,6 +66,9 @@ cc_library( name = "conv", srcs = ["conv.cc"], hdrs = ["conv.h"], + local_defines = if_cuda_is_configured([ + "GOOGLE_CUDA=1", + ]), deps = [ ":support", "//tensorflow/compiler/xla:status", @@ -113,6 +117,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), deps = [ ":support", ":triangular_solve", @@ -132,7 +137,7 @@ cc_library( name = "executable", srcs = ["executable.cc"], hdrs = ["executable.h"], - copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":cholesky", ":collectives", @@ -311,6 +316,7 @@ cc_library( name = "gemm", srcs = ["gemm.cc"], hdrs = ["gemm.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":support", "//tensorflow/compiler/xla:status", @@ -337,6 +343,7 @@ cc_library( name = "graph_launch", srcs = ["graph_launch.cc"], hdrs = ["graph_launch.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":concurrent_region", ":conv", @@ -395,6 +402,7 @@ cc_library( name = "kernel_launch", srcs = ["kernel_launch.cc"], hdrs = ["kernel_launch.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), deps = [ ":concurrent_region", ":support", @@ -417,6 +425,7 @@ cc_library( name = "cublas_lt_matmul", srcs = ["cublas_lt_matmul.cc"], hdrs = ["cublas_lt_matmul.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), deps = [ ":support", "//tensorflow/compiler/xla:xla_proto_cc", @@ -528,6 +537,7 @@ cc_library( name = "triangular_solve", srcs = ["triangular_solve.cc"], hdrs = ["triangular_solve.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM"]), deps = [ ":support", "//tensorflow/compiler/xla:xla_proto_cc", diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index 6e1e2bd9c8ec7a..02c2b41b0a96bb 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -316,6 +316,7 @@ cc_library( srcs = if_gpu_is_configured(["asm_compiler.cc"]), hdrs = if_gpu_is_configured(["asm_compiler.h"]), copts = tsl_copts(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = set_external_visibility([ "//third_party/py/jax/jaxlib:__subpackages__", 
"//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", @@ -360,6 +361,7 @@ cc_library( srcs = if_gpu_is_configured(["redzone_allocator.cc"]), hdrs = if_gpu_is_configured(["redzone_allocator.h"]), copts = tsl_copts(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), visibility = set_external_visibility([ "//tensorflow/compiler/xla/service/gpu:__subpackages__", "//tensorflow/compiler/xla/stream_executor:__subpackages__", From bd153ae36cca06ef10fd9398ba1638ee9c068063 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 02:02:09 -0700 Subject: [PATCH 174/410] Update GraphDef version to 1569. PiperOrigin-RevId: 551134115 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 046d4cc8289420..fa7a6a64e1608c 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1568 // Updated: 2023/7/25 +#define TF_GRAPH_DEF_VERSION 1569 // Updated: 2023/7/26 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 915dd6355b49cf24a8c7c7c19c3ec427c79dc0f6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 02:02:09 -0700 Subject: [PATCH 175/410] compat: Update forward compatibility horizon to 2023-07-26 PiperOrigin-RevId: 551134118 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index a0b5b73528629f..d9a74cbdaeddb3 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 25) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 26) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 858f492f9108f0ee64fd1a1ed7e5e06b8701171a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 02:41:26 -0700 Subject: [PATCH 176/410] Avoid pass by value for a Tensor argument. 
PiperOrigin-RevId: 551144969 --- tensorflow/core/kernels/matmul_op_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 0888d247d4770c..9aa7f70fe657d6 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -88,7 +88,7 @@ struct ParallelMatMulKernel { } static void Run(const OpKernelContext* context, const Tensor& in_x, - const Tensor in_y, bool adj_x, bool adj_y, bool trans_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, bool trans_y, const MatMulBCast& bcast, Tensor* out, int batch_size) { static_assert(IsComplex, "Complex type expected."); From af8b1d11e450fc968714cf9e4affcbeb4d31cddc Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Tue, 25 Jul 2023 15:36:48 +0100 Subject: [PATCH 177/410] [Linaro:ARM_CI] Reduce number of jobs run in parallel for testing Reduce the number of jobs run in parallel to limit the tendency to swap which results in long execution times. --- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh | 2 +- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh index d007774fd03ef0..cae5791083a1fa 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh @@ -119,7 +119,7 @@ bazel test ${TF_TEST_FLAGS} \ --action_env=PYTHON_BIN_PATH=${PYTHON_BIN_PATH} \ --build_tag_filters=${TF_FILTER_TAGS} \ --test_tag_filters=${TF_FILTER_TAGS} \ - --local_test_jobs=$(grep -c ^processor /proc/cpuinfo) \ + --jobs=32 \ --build_tests_only \ -- ${TF_TEST_TARGETS} diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh index 98d49bb9b4f304..817b30da6cff88 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh @@ -172,7 +172,7 @@ bazel test ${TF_TEST_FLAGS} \ --action_env=PYTHON_BIN_PATH=${PYTHON_BIN_PATH} \ --build_tag_filters=${TF_FILTER_TAGS} \ --test_tag_filters=${TF_FILTER_TAGS} \ - --local_test_jobs=$(grep -c ^processor /proc/cpuinfo) \ + --jobs=32 \ --build_tests_only \ -- ${TF_TEST_TARGETS} From fd079e2cc49e912daa3445892a2048184ae87a5c Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 26 Jul 2023 05:50:56 -0700 Subject: [PATCH 178/410] [XLA:GPU] Remove code duplication in DoMultiOutputFusion. 
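The rationale generalizes beyond these scripts: deriving parallelism purely from the processor count can oversubscribe memory on large machines and push the run into swap, so the scripts below switch to a fixed --jobs=32. A small sketch of the same capping idea in C++; the 32 is copied from this change and is otherwise arbitrary:

    #include <algorithm>
    #include <cstdio>
    #include <thread>

    int main() {
      // What the old flag did: scale the job count with the processor count.
      const unsigned detected = std::max(1u, std::thread::hardware_concurrency());
      // The change itself hard-codes 32; taking the min is the equivalent idea,
      // bounding large machines while leaving small ones unchanged.
      const unsigned kMaxJobs = 32;
      const unsigned jobs = std::min(detected, kMaxJobs);
      std::printf("running %u parallel jobs (detected %u CPUs)\n", jobs, detected);
      return 0;
    }
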
PiperOrigin-RevId: 551180029 --- .../xla/service/gpu/multi_output_fusion.cc | 51 +++++-------------- 1 file changed, 13 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 1add20eb6e9b23..3f4555b179f1d0 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -433,48 +433,23 @@ StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(producer)); TF_RETURN_IF_ERROR(cost_analysis.RemoveInstruction(consumer_for_fusion)); + HloInstruction* input_fusion; if (consumer_for_fusion->opcode() == HloOpcode::kFusion) { + input_fusion = consumer_for_fusion; VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " << consumer_for_fusion->name(); - DumpFusionState( - *consumer_for_fusion, - absl::StrCat("About to fuse producer |", producer_name, - "| into consumer |", consumer_for_fusion->name(), - "| inside GPU multi-output fusion"), - /*producer=*/producer); - if (producer->opcode() == HloOpcode::kFusion) { - consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer); - } else { - consumer_for_fusion->FuseInstructionIntoMultiOutput(producer); - CHECK_EQ(0, producer->user_count()); - TF_CHECK_OK(computation_->RemoveInstruction(producer)); - } - TF_RETURN_IF_ERROR(cost_analysis.RevisitInstruction(consumer_for_fusion)); - - DumpFusionState( - *consumer_for_fusion, - absl::StrCat("Fusing producer |", producer_name, "| into consumer |", - consumer_for_fusion->name(), - "| inside GPU multi-output fusion")); - RecomputeReachability(); - GpuPerformanceModel::RecordEstimatedRunTime(consumer_for_fusion, - &cost_analysis); - continue; + } else { + input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion( + consumer_for_fusion->shape(), + ChooseFusionKind(*producer, *consumer_for_fusion), + consumer_for_fusion)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer_for_fusion->name() << " into " + << input_fusion->name(); + TF_CHECK_OK( + computation_->ReplaceInstruction(consumer_for_fusion, input_fusion)); } - HloInstruction* input_fusion = - computation_->AddInstruction(HloInstruction::CreateFusion( - consumer_for_fusion->shape(), - ChooseFusionKind(*producer, *consumer_for_fusion), - consumer_for_fusion)); - VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " - << consumer_for_fusion->name() << " into " << input_fusion->name(); - DumpFusionState( - *input_fusion, - absl::StrCat("About to fuse |", producer_name, "| into consumer |", - input_fusion->name(), "| inside GPU multi-output fusion"), - /*producer=*/input_fusion); - TF_CHECK_OK( - computation_->ReplaceInstruction(consumer_for_fusion, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { input_fusion->MergeFusionInstructionIntoMultiOutput(producer); } else { From f9d4231ed84112571498175b990caa4987e1c64d Mon Sep 17 00:00:00 2001 From: Daniel Lang Date: Wed, 26 Jul 2023 13:16:09 +0000 Subject: [PATCH 179/410] [TFLite] Fix FlatBuffers package name in installed CMake files Similiar to #58677, the capitalization of FlatBuffers needs to match. Otherwise using TFLite via find_package() will fail to find FlatBuffers. 
--- tensorflow/lite/tools/cmake/tensorflow-liteConfig.cmake.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/tools/cmake/tensorflow-liteConfig.cmake.in b/tensorflow/lite/tools/cmake/tensorflow-liteConfig.cmake.in index 04e49106a89a51..8845ba7a8963b5 100644 --- a/tensorflow/lite/tools/cmake/tensorflow-liteConfig.cmake.in +++ b/tensorflow/lite/tools/cmake/tensorflow-liteConfig.cmake.in @@ -17,7 +17,7 @@ include(CMakeFindDependencyMacro) find_dependency(absl) find_dependency(Eigen3) -find_dependency(Flatbuffers) +find_dependency(FlatBuffers) find_dependency(NEON_2_SSE) find_dependency(cpuinfo) find_dependency(ruy) From f95e8b0299b39a526d37af3d2d07c712ce072180 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 26 Jul 2023 06:29:10 -0700 Subject: [PATCH 180/410] Split ir_emission_utils. Create reduction_utils for utils related to reduction codegen. Move some functions to gpu_fusible. PiperOrigin-RevId: 551187475 --- tensorflow/compiler/xla/service/gpu/BUILD | 35 ++- .../compiler/xla/service/gpu/copy_fusion.cc | 1 + .../compiler/xla/service/gpu/gpu_compiler.cc | 1 + .../compiler/xla/service/gpu/gpu_fusible.cc | 41 ++++ .../compiler/xla/service/gpu/gpu_fusible.h | 29 +++ .../xla/service/gpu/hlo_fusion_analysis.cc | 16 ++ .../xla/service/gpu/hlo_fusion_analysis.h | 2 + .../xla/service/gpu/ir_emission_utils.cc | 221 ------------------ .../xla/service/gpu/ir_emission_utils.h | 74 +----- .../xla/service/gpu/ir_emitter_unnested.cc | 2 + .../xla/service/gpu/reduction_splitter.cc | 5 +- .../xla/service/gpu/reduction_utils.cc | 211 +++++++++++++++++ .../xla/service/gpu/reduction_utils.h | 69 ++++++ .../service/gpu/tree_reduction_rewriter.cc | 2 +- 14 files changed, 406 insertions(+), 303 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/reduction_utils.cc create mode 100644 tensorflow/compiler/xla/service/gpu/reduction_utils.h diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 913b27a5bc868c..dc4e6af7944a10 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -313,6 +313,7 @@ cc_library( ":gpu_device_info", ":gpu_executable", ":gpu_fused_mha_runner", + ":gpu_fusible", ":hlo_fusion_analysis", ":hlo_to_ir_bindings", ":ir_emission_utils", @@ -324,6 +325,7 @@ cc_library( ":matmul_utils", ":nccl_collective_thunks", ":parallel_loop_emitter", + ":reduction_utils", ":target_util", ":thunk", "//tensorflow/compiler/xla:autotuning_proto_cc", @@ -990,7 +992,6 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], compatible_with = get_compatible_with_portable(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":target_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1008,10 +1009,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Core", "@llvm-project//mlir:ArithDialect", - ] + if_cuda_is_configured([ - ":gpu_asm_opts_util", - "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", - ]), + ], ) xla_cc_test( @@ -1030,6 +1028,25 @@ xla_cc_test( ], ) +cc_library( + name = "reduction_utils", + srcs = ["reduction_utils.cc"], + hdrs = ["reduction_utils.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/tsl/platform:logging", + 
"@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + ] + if_cuda_is_configured([ + ":gpu_asm_opts_util", + "//tensorflow/compiler/xla/stream_executor/gpu:asm_compiler", + ]), +) + cc_library( name = "cublas_cudnn", srcs = ["cublas_cudnn.cc"], @@ -2342,6 +2359,7 @@ cc_library( ":reduction_dimension_grouper", ":reduction_layout_normalizer", ":reduction_splitter", + ":reduction_utils", ":runtime_intrinsics", ":scatter_slice_simplifier", ":softmax_rewriter_triton", @@ -2964,6 +2982,7 @@ cc_library( ":gpu_fusible", ":ir_emission_utils", ":launch_dimensions", + ":reduction_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", @@ -3132,6 +3151,7 @@ cc_library( deps = [ ":gpu_device_info", ":ir_emission_utils", + ":reduction_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", @@ -3483,7 +3503,7 @@ cc_library( srcs = ["reduction_splitter.cc"], hdrs = ["reduction_splitter.h"], deps = [ - ":ir_emission_utils", + ":reduction_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -3529,7 +3549,7 @@ cc_library( hdrs = ["tree_reduction_rewriter.h"], deps = [ ":gpu_types", - ":ir_emission_utils", + ":reduction_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -3934,6 +3954,7 @@ cc_library( hdrs = ["copy_fusion.h"], deps = [ ":ir_emission_utils", + ":reduction_utils", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_pass", diff --git a/tensorflow/compiler/xla/service/gpu/copy_fusion.cc b/tensorflow/compiler/xla/service/gpu/copy_fusion.cc index eb8a47d5f72902..3fff01641b5c86 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_fusion.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 557abcdf6b1171..c150e80ea4e103 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -120,6 +120,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/service/gpu/runtime_intrinsics.h" #include "tensorflow/compiler/xla/service/gpu/scatter_slice_simplifier.h" #include "tensorflow/compiler/xla/service/gpu/softmax_rewriter_triton.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 5a548afd788e93..bb25d173cda2fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -739,5 +740,45 @@ size_t GetOutputSizeOfFusible(const HloInstruction& instr) { return ShapeUtil::TupleElementCount(root->shape()); } +// Recursive helper for GetFusionRoots below. +static void GetFusionRootsRec(HloInstruction* root, + std::vector& out) { + if (root->opcode() == HloOpcode::kGetTupleElement) { + return GetFusionRootsRec(root->mutable_operand(0), out); + } else if (root->opcode() == HloOpcode::kTuple) { + for (int i = 0; i < root->operand_count(); i++) { + GetFusionRootsRec(root->mutable_operand(i), out); + } + } else { + if (!out.empty() && out.back() == root) { + return; + } + CHECK(!absl::c_linear_search(out, root)) + << "Fusion root contains instruction " << root->ToString() + << " multiple times"; + out.push_back(root); + } +} + +std::vector GetFusionRoots(HloComputation* computation) { + std::vector out; + GetFusionRootsRec(computation->root_instruction(), out); + return out; +} + +bool HasAnyTiledTransposeRoot(HloComputation* computation) { + return absl::c_any_of(GetFusionRoots(computation), + [&](const HloInstruction* instr) { + return FindAnyTiledTranspose(*instr); + }); +} + +bool HasAnyUnnestedReductionRoot(HloComputation* computation) { + return absl::c_any_of( + GetFusionRoots(computation), [&](const HloInstruction* instr) { + return IsReductionFromOrToContiguousDimensions(*instr); + }); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index 31fba3c0f9a39b..536dcf91e0088d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -158,6 +161,32 @@ absl::InlinedVector GetOutputsOfFusible( // Returns the output size of the fusible `instr`. size_t GetOutputSizeOfFusible(const HloInstruction& instr); +// Returns instructions which are roots of the fusion, following the operands of +// GTE instructions in the root tuple. Groups multiple subsequent instructions +// with the same root. CHECKs that the fusion never outputs the same instruction +// twice, as well as that there are no explicitly created tuples or nested gtes +// in fusion output. +// +// For input: (tuple (gte R1) (gte R1) O2) +// Expected output: [R1, O2] +// +// For input: (tuple R1 R2 O2) +// Expected output: [R1, R2, O2] +// +// For input: (tuple (gte R1) (gte R1) R2 O3) +// Expected output: [R1, R2, O3] +// +// For input: R1 +// Expected output: [R1] +std::vector GetFusionRoots(HloComputation* computation); + +// Whether there is a fusion root triggering transposition emitter. 
+bool HasAnyTiledTransposeRoot(HloComputation* computation); + +// Returns whether the computation has at least one root triggering unnested +// reduction emitter. +bool HasAnyUnnestedReductionRoot(HloComputation* computation); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc index 2824ad9ee2c4cf..a2a80633f8c0e4 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.cc @@ -357,6 +357,22 @@ StatusOr HloFusionAnalysis::GetLaunchDimensions( } } +namespace { +// Returns the hero reduction of the computation. +// We always use the first reduce root that triggers unnested reduction emitter +// as the hero reduction, since all the reductions are required to have the same +// shape and layout as verified by `IsFusedReductionOutputConsistent()`. +HloInstruction* FindHeroReduction(absl::Span roots) { + auto it = absl::c_find_if(roots, [](HloInstruction* instr) { + return IsReductionFromOrToContiguousDimensions(*instr); + }); + if (it == roots.end()) { + return nullptr; + } + return *it; +} +} // namespace + const ReductionCodegenInfo* HloFusionAnalysis::GetReductionCodegenInfo() { if (reduction_codegen_info_.has_value()) { return &reduction_codegen_info_.value(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h index 9ded64edabfe68..ed7a1d4d6c9421 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h @@ -25,9 +25,11 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index b5827e1ec08417..9d6886471c23a9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -40,11 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#ifdef GOOGLE_CUDA -#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" -#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" -#endif // GOOGLE_CUDA - namespace xla { namespace gpu { @@ -55,37 +50,6 @@ bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 2; } -// Given a shape and a group of contiguous dimensions in the shape, returns -// a tuple of three values (major, middle, minor), where major is the size of -// the dimensions more major then the given dimensions, minor is the size of -// dimensions more minor then the given dimensions, and middle is the size of -// the given dimensions. 
-Vector3 PartitionShapeByMiddleDimensions( - const Shape& shape, absl::Span dims_middle) { - CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle)); - Vector3 values = {1, 1, 1}; - enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 }; - Segment cur_segment = kMinor; - - for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) { - if (cur_segment != kMajor) { - // Handle change of segments. - bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim); - if (cur_segment == kMinor) { - if (cur_dim_in_middle) { - cur_segment = kMiddle; - } - } else if (cur_segment == kMiddle) { - if (!cur_dim_in_middle) { - cur_segment = kMajor; - } - } - } - values[cur_segment] *= shape.dimensions(cur_dim); - } - return values; -} - Shape GetShapeFromTensorType(mlir::Value value) { constexpr char kDefaultLayoutAttrName[] = "xla_shape"; @@ -141,34 +105,6 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { return true; } -int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config) { -#ifdef GOOGLE_CUDA - auto ptxas_config = - PtxOptsFromDebugOptions(hlo_module_config.debug_options()); - auto ptxas_version_tuple = - se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir); - // ptxas versions prior to 12.2 have a very rare bug when very high register - // spilling occurs with some order of instructions, so use less threads to - // reduce register pressure. - if (!ptxas_version_tuple.ok() || - ptxas_version_tuple.value() < std::array{12, 2, 0}) { - return 512; - } -#endif // GOOGLE_CUDA - return 1024; -} - -Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { - if (reduction_dimensions.is_row_reduction) { - int64_t tile_z = std::min(reduction_dimensions.dimensions[0], - BatchedReductionRaceFreeBound()); - return {tile_z, 1, 16}; - } - - // Column reduction. - return {1, 128, 1}; -} - const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky"; bool IsCustomCallToCusolver(const HloInstruction& hlo) { @@ -178,59 +114,6 @@ bool IsCustomCallToCusolver(const HloInstruction& hlo) { return hlo.custom_call_target() == kCusolverCholeskyCallTarget; } -static bool IsUnnestedReductionFasterThanElemental( - const ReductionDimensions& reduction_dimensions) { - if (reduction_dimensions.is_row_reduction) { - // For row reduction, the tile block is 1 x tile_size_x, and we are reducing - // along tile_size_x which needs to be large enough to make the tiling - // implementation efficient. - // For very small reductions with a power-of-two size, we can fit multiple - // reductions inside a single warp, which is more efficient than a loop. - return (reduction_dimensions.dimensions[2] >= WarpSize()) || - ((WarpSize() % reduction_dimensions.dimensions[2]) == 0); - } - - // For column reduction, the tile block is tile_size_y x tile_size_x, and we - // are reducing along tile_size_y. Only tile_size_y needs to be - // large enough to make the tiling implementation efficient. - int64_t major_size = reduction_dimensions.dimensions[1]; - int64_t minor_size = reduction_dimensions.dimensions[2]; - - // Rule generated by sweeping the search space of small column reductions. 
- bool prefer_elemental_emitter = - (major_size < WarpSize()) || - (major_size < 2 * WarpSize() && minor_size < WarpSize()) || - (major_size < 4 * WarpSize() && minor_size < 8) || - (major_size < 8 * WarpSize() && minor_size < 3); - - return !prefer_elemental_emitter; -} - -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { - if (reduce.opcode() != HloOpcode::kReduce) { - return false; - } - - const Shape& operand_shape = reduce.operand(0)->shape(); - absl::Span dims_to_reduce = reduce.dimensions(); - DimensionVector dims_to_keep; - for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) { - if (!absl::c_linear_search(dims_to_reduce, dim)) { - dims_to_keep.push_back(dim); - } - } - - // We support fast codegen for three cases: - // 1) Row reduction: (K, R) - // 2) Column reduction: (K, R, K) - // 3) "Batched" row reduction: (R, K, R) - return (LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), - dims_to_keep) || - LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), - dims_to_reduce)) && - IsUnnestedReductionFasterThanElemental( - GetReductionKindAndContiguousComponents(reduce)); -} bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, bool verify_no_strides) { @@ -257,46 +140,6 @@ bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, return true; } -ReductionDimensions GetReductionKindAndContiguousComponents( - const HloInstruction& reduce) { - Shape input_shape = reduce.operand(0)->shape(); - absl::Span dims_to_reduce = reduce.dimensions(); - DimensionVector dims_to_keep; - for (int64_t dim = 0; dim < input_shape.rank(); ++dim) { - if (!absl::c_linear_search(dims_to_reduce, dim)) { - dims_to_keep.push_back(dim); - } - } - - if (dims_to_keep.empty()) { - return {/*is_row_reduction=*/true, - {1, 1, ShapeUtil::ElementsIn(input_shape)}}; - } - - if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(), - dims_to_keep)) { - Vector3 shape_partition = - PartitionShapeByMiddleDimensions(input_shape, dims_to_keep); - if (shape_partition[1] == 1) { - return {/*is_row_reduction=*/true, - {1, 1, shape_partition[0] * shape_partition[2]}}; - } - if (shape_partition[2] == 1) { - return {/*is_row_reduction=*/false, - {1, shape_partition[0], shape_partition[1]}}; - } - return {/*is_row_reduction=*/true, shape_partition}; - } - - Vector3 shape_partition = - PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce); - - if (shape_partition[2] == 1) { - return {/*is_row_reduction=*/true, - {1, shape_partition[0], shape_partition[1]}}; - } - return {/*is_row_reduction=*/false, shape_partition}; -} // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see @@ -799,46 +642,6 @@ Shape GetShape(mlir::Value value) { return {}; } -bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions) { - Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); - return (reduction_dimensions.is_row_reduction && - reduction_dimensions.dimensions[2] <= - MinThreadsXRowReduction(hlo_module_config) * - reduction_tiling[2] && - reduction_dimensions.dimensions[0] <= - BatchedReductionRaceFreeBound()) || - (!reduction_dimensions.is_row_reduction && - reduction_dimensions.dimensions[1] <= - WarpSize() * reduction_tiling[1]); -} - -// Recursive helper for GetFusionRoots below. 
-static void GetFusionRootsRec(HloInstruction* root, - std::vector& out) { - if (root->opcode() == HloOpcode::kGetTupleElement) { - return GetFusionRootsRec(root->mutable_operand(0), out); - } else if (root->opcode() == HloOpcode::kTuple) { - for (int i = 0; i < root->operand_count(); i++) { - GetFusionRootsRec(root->mutable_operand(i), out); - } - } else { - if (!out.empty() && out.back() == root) { - return; - } - CHECK(!absl::c_linear_search(out, root)) - << "Fusion root contains instruction " << root->ToString() - << " multiple times"; - out.push_back(root); - } -} - -std::vector GetFusionRoots(HloComputation* computation) { - std::vector out; - GetFusionRootsRec(computation->root_instruction(), out); - return out; -} - std::optional FindTiledTranspose( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kCopy) { @@ -985,30 +788,6 @@ const HloInstruction& FindNonTrivialHero(const HloInstruction& instr) { return *non_trivial_hero; } -bool HasAnyTiledTransposeRoot(HloComputation* computation) { - return absl::c_any_of(GetFusionRoots(computation), - [&](const HloInstruction* instr) { - return FindAnyTiledTranspose(*instr); - }); -} - -bool HasAnyUnnestedReductionRoot(HloComputation* computation) { - return absl::c_any_of( - GetFusionRoots(computation), [&](const HloInstruction* instr) { - return IsReductionFromOrToContiguousDimensions(*instr); - }); -} - -HloInstruction* FindHeroReduction(absl::Span roots) { - auto it = absl::c_find_if(roots, [](HloInstruction* instr) { - return IsReductionFromOrToContiguousDimensions(*instr); - }); - if (it == roots.end()) { - return nullptr; - } - return *it; -} - void LogAndVerify(const llvm::Module* m) { if (VLOG_IS_ON(5)) { XLA_VLOG_LINES(5, llvm_ir::DumpToString(m)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 032237d3f8d5f4..7c649cb4d3dae6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -47,13 +47,6 @@ bool IsMatrixMultiplication(const HloInstruction& dot); inline constexpr int64_t WarpSize() { return 32; } -// Need at least 1024 threads/block for reasonable tree reduction -// performance (assuming all data fits). -int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); - -// When doing batched row reduction, how big the batch dimension could be. -inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } - // Fusions that use Triton have FusionBackendConfig.kind equal to this string. inline constexpr absl::string_view kTritonGemmFusionKind = "__triton_gemm"; @@ -78,40 +71,12 @@ bool IsCustomCallToCusolver(const HloInstruction& hlo); // is a success/failure code per batch element. extern const char* const kCusolverCholeskyCallTarget; -// Returns true if either the dimensions being reduced or the dimensions being -// kept are contiguous in the input of the reduce instruction. -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); - // Returns whether unnested_hlo is an input fusion whose root is either a slice // or a tuple of slices. If verify_no_strides is true, returns false unless all // ROOT slices have no strides. bool IsInputFusibleSlices(mlir::Operation* unnested_hlo, bool verify_no_strides); -struct ReductionDimensions { - // Indicates whether the reduction is a row reduction or a column reduction. 
- bool is_row_reduction; - - // Contains the size of the three contiguous components for - // the reduction [depth, height, width] (major-to-minor ordering). - // - // For row reduction, we do: [D, H, W] -> [D, H]. - // For column reduction, we do: [D, H, W] -> [D, W]. - Vector3 dimensions; -}; - -// Given the input shape and dimensions to reduce for a reduction, returns -// ReductionDimensions. -// -// Prerequisite: the reduction instruction passes the check -// IsReductionFromOrToContiguousDimensions, which guarantees either the -// dimensions to reduce or the dimensions to keep are consecutive. -ReductionDimensions GetReductionKindAndContiguousComponents( - const HloInstruction& reduce); - -// Get tiling per thread for the given reduction in dimensions [D, H, W]. -Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); - // Emits call to "vprintf" with given format and arguments. llvm::Value* EmitPrintf(absl::string_view fmt, absl::Span arguments, @@ -171,12 +136,7 @@ GetOutputDefiningDynamicUpdateSliceOps(mlir::lmhlo::FusionOp fusion); Shape GetShape(mlir::Value value); -// Returns whether the given reduction can be safely generated without atomics: -// that is, at most one block will write to every output element. -bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions); - -// Description of how to emit a given transposition. +/// Description of how to emit a given transposition. // // On a group of input parameters that are 0-2-1 transpose of the outputs // of a fusion kernel, stores the input parameters that are safe for the @@ -204,40 +164,8 @@ struct TransposeDimsAndParams { } }; -// Returns instructions which are roots of the fusion, following the operands of -// GTE instructions in the root tuple. Groups multiple subsequent instructions -// with the same root. CHECKs that the fusion never outputs the same instruction -// twice, as well as that there are no explicitly created tuples or nested gtes -// in fusion output. -// -// For input: (tuple (gte R1) (gte R1) O2) -// Expected output: [R1, O2] -// -// For input: (tuple R1 R2 O2) -// Expected output: [R1, R2, O2] -// -// For input: (tuple (gte R1) (gte R1) R2 O3) -// Expected output: [R1, R2, O3] -// -// For input: R1 -// Expected output: [R1] -std::vector GetFusionRoots(HloComputation* computation); - -// Returns whether the computation has at least one root triggering unnested -// reduction emitter. -bool HasAnyUnnestedReductionRoot(HloComputation* computation); - -// Returns the hero reduction of the computation. -// We always use the first reduce root that triggers unnested reduction emitter -// as the hero reduction, since all the reductions are required to have the same -// shape and layout as verified by `IsFusedReductionOutputConsistent()`. -HloInstruction* FindHeroReduction(absl::Span roots); - const HloInstruction& FindNonTrivialHero(const HloInstruction& instr); -// Whether there is a fusion root triggering transposition emitter. -bool HasAnyTiledTransposeRoot(HloComputation* computation); - struct TransposeDescription { Vector3 dimensions; Vector3 permutation; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index d51ef9046c6e5b..98f30f7078ceba 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -99,6 +99,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fused_mha_runner.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/hlo_fusion_analysis.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" @@ -118,6 +119,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" diff --git a/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc b/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc index 2baa2328a115c7..473545f543f1d3 100644 --- a/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc +++ b/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc @@ -16,10 +16,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h" #include +#include +#include +#include #include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/shape_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/reduction_utils.cc b/tensorflow/compiler/xla/service/gpu/reduction_utils.cc new file mode 100644 index 00000000000000..a33406f937bc64 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_utils.cc @@ -0,0 +1,211 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/tsl/platform/logging.h" + +#ifdef GOOGLE_CUDA +#include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" +#endif // GOOGLE_CUDA + +namespace xla { +namespace gpu { + +namespace { +// Given a shape and a group of contiguous dimensions in the shape, returns +// a tuple of three values (major, middle, minor), where major is the size of +// the dimensions more major then the given dimensions, minor is the size of +// dimensions more minor then the given dimensions, and middle is the size of +// the given dimensions. +Vector3 PartitionShapeByMiddleDimensions( + const Shape& shape, absl::Span dims_middle) { + CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle)); + Vector3 values = {1, 1, 1}; + enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 }; + Segment cur_segment = kMinor; + + for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) { + if (cur_segment != kMajor) { + // Handle change of segments. + bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim); + if (cur_segment == kMinor) { + if (cur_dim_in_middle) { + cur_segment = kMiddle; + } + } else if (cur_segment == kMiddle) { + if (!cur_dim_in_middle) { + cur_segment = kMajor; + } + } + } + values[cur_segment] *= shape.dimensions(cur_dim); + } + return values; +} +} // namespace + +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config) { +#ifdef GOOGLE_CUDA + auto ptxas_config = + PtxOptsFromDebugOptions(hlo_module_config.debug_options()); + auto ptxas_version_tuple = + se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir); + // ptxas versions prior to 12.2 have a very rare bug when very high register + // spilling occurs with some order of instructions, so use less threads to + // reduce register pressure. + if (!ptxas_version_tuple.ok() || + ptxas_version_tuple.value() < std::array{12, 2, 0}) { + return 512; + } +#endif // GOOGLE_CUDA + return 1024; +} + +Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { + if (reduction_dimensions.is_row_reduction) { + int64_t tile_z = std::min(reduction_dimensions.dimensions[0], + BatchedReductionRaceFreeBound()); + return {tile_z, 1, 16}; + } + + // Column reduction. + return {1, 128, 1}; +} + +static bool IsUnnestedReductionFasterThanElemental( + const ReductionDimensions& reduction_dimensions) { + const int kWarpSize = 32; + if (reduction_dimensions.is_row_reduction) { + // For row reduction, the tile block is 1 x tile_size_x, and we are reducing + // along tile_size_x which needs to be large enough to make the tiling + // implementation efficient. + // For very small reductions with a power-of-two size, we can fit multiple + // reductions inside a single warp, which is more efficient than a loop. 
+ return (reduction_dimensions.dimensions[2] >= kWarpSize) || + ((kWarpSize % reduction_dimensions.dimensions[2]) == 0); + } + + // For column reduction, the tile block is tile_size_y x tile_size_x, and we + // are reducing along tile_size_y. Only tile_size_y needs to be + // large enough to make the tiling implementation efficient. + int64_t major_size = reduction_dimensions.dimensions[1]; + int64_t minor_size = reduction_dimensions.dimensions[2]; + + // Rule generated by sweeping the search space of small column reductions. + bool prefer_elemental_emitter = + (major_size < kWarpSize) || + (major_size < 2 * kWarpSize && minor_size < kWarpSize) || + (major_size < 4 * kWarpSize && minor_size < 8) || + (major_size < 8 * kWarpSize && minor_size < 3); + + return !prefer_elemental_emitter; +} + +bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { + if (reduce.opcode() != HloOpcode::kReduce) { + return false; + } + + const Shape& operand_shape = reduce.operand(0)->shape(); + absl::Span dims_to_reduce = reduce.dimensions(); + DimensionVector dims_to_keep; + for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) { + if (!absl::c_linear_search(dims_to_reduce, dim)) { + dims_to_keep.push_back(dim); + } + } + + // We support fast codegen for three cases: + // 1) Row reduction: (K, R) + // 2) Column reduction: (K, R, K) + // 3) "Batched" row reduction: (R, K, R) + return (LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), + dims_to_keep) || + LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), + dims_to_reduce)) && + IsUnnestedReductionFasterThanElemental( + GetReductionKindAndContiguousComponents(reduce)); +} + +bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions) { + const int kWarpSize = 32; + Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); + return (reduction_dimensions.is_row_reduction && + reduction_dimensions.dimensions[2] <= + MinThreadsXRowReduction(hlo_module_config) * + reduction_tiling[2] && + reduction_dimensions.dimensions[0] <= + BatchedReductionRaceFreeBound()) || + (!reduction_dimensions.is_row_reduction && + reduction_dimensions.dimensions[1] <= + kWarpSize * reduction_tiling[1]); +} + +ReductionDimensions GetReductionKindAndContiguousComponents( + const HloInstruction& reduce) { + Shape input_shape = reduce.operand(0)->shape(); + absl::Span dims_to_reduce = reduce.dimensions(); + DimensionVector dims_to_keep; + for (int64_t dim = 0; dim < input_shape.rank(); ++dim) { + if (!absl::c_linear_search(dims_to_reduce, dim)) { + dims_to_keep.push_back(dim); + } + } + + if (dims_to_keep.empty()) { + return {/*is_row_reduction=*/true, + {1, 1, ShapeUtil::ElementsIn(input_shape)}}; + } + + if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(), + dims_to_keep)) { + Vector3 shape_partition = + PartitionShapeByMiddleDimensions(input_shape, dims_to_keep); + if (shape_partition[1] == 1) { + return {/*is_row_reduction=*/true, + {1, 1, shape_partition[0] * shape_partition[2]}}; + } + if (shape_partition[2] == 1) { + return {/*is_row_reduction=*/false, + {1, shape_partition[0], shape_partition[1]}}; + } + return {/*is_row_reduction=*/true, shape_partition}; + } + + Vector3 shape_partition = + PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce); + + if (shape_partition[2] == 1) { + return {/*is_row_reduction=*/true, + {1, shape_partition[0], shape_partition[1]}}; + } + return {/*is_row_reduction=*/false, shape_partition}; +} + +} // namespace 
gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/reduction_utils.h b/tensorflow/compiler/xla/service/gpu/reduction_utils.h new file mode 100644 index 00000000000000..5f87f0347e3e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_utils.h @@ -0,0 +1,69 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_UTILS_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { + +// Need at least 1024 threads/block for reasonable tree reduction +// performance (assuming all data fits). +int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); + +// When doing batched row reduction, how big the batch dimension could be. +inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } + +// Returns true if either the dimensions being reduced or the dimensions being +// kept are contiguous in the input of the reduce instruction. +bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); + +struct ReductionDimensions { + // Indicates whether the reduction is a row reduction or a column reduction. + bool is_row_reduction; + + // Contains the size of the three contiguous components for + // the reduction [depth, height, width] (major-to-minor ordering). + // + // For row reduction, we do: [D, H, W] -> [D, H]. + // For column reduction, we do: [D, H, W] -> [D, W]. + Vector3 dimensions; +}; + +// Given the input shape and dimensions to reduce for a reduction, returns +// ReductionDimensions. +// +// Prerequisite: the reduction instruction passes the check +// IsReductionFromOrToContiguousDimensions, which guarantees either the +// dimensions to reduce or the dimensions to keep are consecutive. +ReductionDimensions GetReductionKindAndContiguousComponents( + const HloInstruction& reduce); + +// Get tiling per thread for the given reduction in dimensions [D, H, W]. +Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); + +// Returns whether the given reduction can be safely generated without atomics : +// that is, at most one block will write to every output element. 
+bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, + const ReductionDimensions& reduction_dimensions); + +} // namespace gpu +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc index b2f5b48fcb17e7..7d1ff1b7bb840e 100644 --- a/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/tree_reduction_rewriter.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/collective_ops_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_utils.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" From 5b2c3dd7866f9e9e4003817e65fefbaf7f2c0697 Mon Sep 17 00:00:00 2001 From: "Jiyoun (Jen) Ha" Date: Wed, 26 Jul 2023 07:14:45 -0700 Subject: [PATCH 181/410] Register stablehlo dialect for tf_tfl_translate PiperOrigin-RevId: 551196576 --- tensorflow/compiler/mlir/lite/BUILD | 2 +- tensorflow/compiler/mlir/lite/tf_tfl_translate.cc | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 73aa53f9d54479..17c52a69c3e67f 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1209,7 +1209,6 @@ tf_cc_binary( "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", - "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -1219,6 +1218,7 @@ tf_cc_binary( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 7187bebd594bba..a4c4b36f87b14c 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" @@ -164,7 +165,8 @@ int main(int argc, char **argv) { // back to do it properly in the future mlir::DialectRegistry registry; RegisterAllTensorFlowDialects(registry); - registry.insert(); + registry + .insert(); context.appendDialectRegistry(registry); } From 6d990a87f15d6e28397c5d38b56d2545eebbd9ed Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Wed, 26 Jul 2023 08:26:36 -0700 Subject: [PATCH 182/410] [XLA:GPU] Make float normalization for convolutions ignore other types of HLOs. 
PiperOrigin-RevId: 551212985 --- tensorflow/compiler/xla/service/gpu/BUILD | 15 +++++++ .../xla/service/gpu/float_support_test.cc | 44 +++++++++++++++++++ .../xla/service/gpu/nvptx_compiler.cc | 5 +++ 3 files changed, 64 insertions(+) create mode 100644 tensorflow/compiler/xla/service/gpu/float_support_test.cc diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index dc4e6af7944a10..34b6bf9b5e9382 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3266,6 +3266,21 @@ xla_cc_test( ], ) +xla_test( + name = "float_support_test", + srcs = ["float_support_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + backends = [ + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + xla_cc_test( name = "conv_layout_normalization_test", srcs = ["conv_layout_normalization_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/float_support_test.cc b/tensorflow/compiler/xla/service/gpu/float_support_test.cc new file mode 100644 index 00000000000000..b711b7ef2f856d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/float_support_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using FloatSupportTest = HloTestBase; + +TEST_F(FloatSupportTest, MixedTypeDotIsNotUpcasted) { + const std::string kHloText = R"( +ENTRY e { + p0 = bf16[32,32] parameter(0) + p1 = bf16[32,32] parameter(1) + ROOT d = f32[32,32] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK-NOT: convert +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{1e-6, 1e-6})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 319218b18660e8..17a042a3959b3f 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -102,6 +102,11 @@ class ConvBfloat16Support : public FloatSupport { return (hlo.opcode() != HloOpcode::kConvolution) || is_conv_bf16_supported_; } + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + // Skip all HLOs other than convolutions. + return (hlo.opcode() != HloOpcode::kConvolution); + } + private: bool is_conv_bf16_supported_; }; From 9720b405905dee209a3f7d003de21d388e1aaef4 Mon Sep 17 00:00:00 2001 From: Nathan Luehr Date: Tue, 25 Jul 2023 17:33:15 -0500 Subject: [PATCH 183/410] Avoid nullptr as row offsets to cusparseCreateCsr As of CUDA 12.2 additional input validation allows NULL for the row offsets only when rows=0. 
--- tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc index e6db9bd83d730d..e424db601b0f34 100644 --- a/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc @@ -493,7 +493,7 @@ class CSRSparseMatMulGPUOp : public OpKernel { matC.InitializeCsr( a_input_dense_shape(a_input_dense_shape.size() - 2), b_input_dense_shape(b_input_dense_shape.size() - 1), 0, - nullptr, nullptr, nullptr)); + c_row_ptr.data(), nullptr, nullptr)); // Check required size for buffer1 and possibly re-allocate size_t bufferSize1; From 63abf159be48a9c7e67cd41b13f48411080a888b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 09:04:02 -0700 Subject: [PATCH 184/410] Added attributed to IFRT CallOp that specifies if an atom program is in global or local view. If the attribute is set on a CallOp, then verification logic converts the programs arguments and results from local view to global view to verify that local view shape + sharding is equivalent to the expected global view shape. PiperOrigin-RevId: 551222813 --- .../compiler/xla/python/ifrt/ir/constants.h | 4 + .../compiler/xla/python/ifrt/ir/ifrt_ops.cc | 89 +++++++++++++++++-- .../xla/python/ifrt/ir/tests/verify_call.mlir | 31 +++++++ 3 files changed, 115 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/ir/constants.h b/tensorflow/compiler/xla/python/ifrt/ir/constants.h index ac9841167be216..66c14071134de1 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/constants.h +++ b/tensorflow/compiler/xla/python/ifrt/ir/constants.h @@ -29,6 +29,10 @@ inline constexpr llvm::StringLiteral kIfrtFunctionAttrName = "ifrt.function"; // Must be used in a FuncOp with `ifrt.function` attr. inline constexpr llvm::StringLiteral kIfrtDonatedArgAttrName = "ifrt.donated"; +// Name of UnitAttr on CallOp used to indicate that the atom program is +// in "local" view (i.e., already sharded). +inline constexpr llvm::StringLiteral kIfrtLocalViewAttrName = "ifrt.local_view"; + } // namespace ifrt } // namespace xla diff --git a/tensorflow/compiler/xla/python/ifrt/ir/ifrt_ops.cc b/tensorflow/compiler/xla/python/ifrt/ir/ifrt_ops.cc index 179e86b7cd6f3e..e5fd010c9ac7bf 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/ifrt_ops.cc +++ b/tensorflow/compiler/xla/python/ifrt/ir/ifrt_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/xla/python/ifrt/ir/constants.h" #include "tensorflow/compiler/xla/python/ifrt/ir/ifrt_dialect.h" // Generated definitions. 
@@ -59,6 +60,25 @@ mlir::FailureOr GetGlobalShape(mlir::Value value) { return GetGlobalShape(value.getType()); } +mlir::FailureOr GetGlobalShapeFromLocal( + mlir::Type type, ShardingParam shard_param) { + if (auto local_ranked_tensor = type.dyn_cast()) { + llvm::SmallVector global_shape; + auto local_shape = local_ranked_tensor.getShape(); + if (local_shape.size() != shard_param.dim_shards().size()) { + return mlir::failure(); + } + for (auto [idx, dim_shard] : llvm::enumerate(shard_param.dim_shards())) { + global_shape.push_back(dim_shard * local_shape[idx]); + } + return mlir::RankedTensorType::get(global_shape, + local_ranked_tensor.getElementType()); + } else { + // IFRT arrays cannot be in the local view. + return mlir::failure(); + } +} + template mlir::LogicalResult VerifySameGlobalShape(mlir::Operation* op, llvm::StringRef lhs_mnemonic, T lhs, @@ -81,6 +101,40 @@ mlir::LogicalResult VerifySameGlobalShape(mlir::Operation* op, return mlir::success(); } +// Verifies that the global shape of a call op argument/result is the same +// as the global shape of corresponding argument/result of the function in +// local view. +mlir::LogicalResult VerifyGlobalLocalShapesEquivalent( + mlir::Operation* op, llvm::StringRef call_mnemonic, mlir::Value call_value, + llvm::StringRef callee_mnemonic, mlir::Type callee_type) { + // The call values are in the global view. + mlir::FailureOr call_shape = + GetGlobalShape(call_value); + if (mlir::failed(call_shape)) { + return op->emitOpError() << "fails to get global shape from " + << call_mnemonic << ": " << call_value; + } + // The types of the CallOp func signature must be IfrtArrayType. + auto array = call_value.getType().dyn_cast(); + if (array == nullptr) { + return mlir::failure(); + } + // Convert from local shape to global shape using the sharding provided + // by the CallOp func signature. + mlir::FailureOr callee_shape = + GetGlobalShapeFromLocal(callee_type, array.getSharding()); + if (mlir::failed(callee_shape)) { + return op->emitOpError() << "fails to get global shape from " + << callee_mnemonic << ": " << callee_type; + } + if (*call_shape != *callee_shape) { + return op->emitOpError() + << "requires the same global shape. " << call_mnemonic << " " + << *call_shape << " vs " << callee_mnemonic << " " << *callee_shape; + } + return mlir::success(); +} + // Verifies that each of `inputs` and `outputs` is placed on a subset of // `devices`. mlir::LogicalResult VerifyDevicePlacement( @@ -237,7 +291,8 @@ mlir::LogicalResult CallOp::verifySymbolUses( mlir::SymbolTableCollection& symbol_table) { mlir::func::FuncOp callee = getCalleeOp(symbol_table); mlir::FunctionType callee_type = callee.getFunctionType(); - + auto local_view_attr = + (*this)->getAttrOfType(kIfrtLocalViewAttrName); // Verify inputs. if (callee_type.getNumInputs() != getInputs().size()) { return emitOpError() << "requires the same input size. 
Input " @@ -245,10 +300,18 @@ mlir::LogicalResult CallOp::verifySymbolUses( << callee_type.getNumInputs(); } for (int i = 0; i < callee_type.getNumInputs(); ++i) { - if (mlir::failed(VerifySameGlobalShape( - *this, llvm::Twine("Input #").concat(llvm::Twine(i)).str(), - getInputs()[i], "Callee", callee_type.getInput(i)))) { - return mlir::failure(); + if (local_view_attr == nullptr) { + if (mlir::failed(VerifySameGlobalShape( + *this, llvm::Twine("Input #").concat(llvm::Twine(i)).str(), + getInputs()[i], "Callee", callee_type.getInput(i)))) { + return mlir::failure(); + } + } else { + if (mlir::failed(VerifyGlobalLocalShapesEquivalent( + *this, llvm::Twine("Input #").concat(llvm::Twine(i)).str(), + getInputs()[i], "Callee", callee_type.getInput(i)))) { + return mlir::failure(); + } } } @@ -259,10 +322,18 @@ mlir::LogicalResult CallOp::verifySymbolUses( << callee_type.getNumResults(); } for (int i = 0; i < callee_type.getNumResults(); ++i) { - if (mlir::failed(VerifySameGlobalShape( - *this, llvm::Twine("Output #").concat(llvm::Twine(i)).str(), - getOutputs()[i], "Callee", callee_type.getResult(i)))) { - return mlir::failure(); + if (local_view_attr == nullptr) { + if (mlir::failed(VerifySameGlobalShape( + *this, llvm::Twine("Output #").concat(llvm::Twine(i)).str(), + getOutputs()[i], "Callee", callee_type.getResult(i)))) { + return mlir::failure(); + } + } else { + if (mlir::failed(VerifyGlobalLocalShapesEquivalent( + *this, llvm::Twine("Output #").concat(llvm::Twine(i)).str(), + getOutputs()[i], "Callee", callee_type.getResult(i)))) { + return mlir::failure(); + } } } diff --git a/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_call.mlir b/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_call.mlir index 7bc2353b7b8a4a..444de6ae58b873 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_call.mlir @@ -311,3 +311,34 @@ func.func @io_aliases_should_have_same_type( func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { return %arg0 : tensor<2x2xi32> } + +// ----- + +func.func @good_call_local_view( + %arg0: !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1,2,3] {ifrt.local_view} + : (!ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + -> !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]> + return +} + +func.func @callee(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +func.func @call_local_view_should_have_valid_shape( + %arg0: !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op requires the same global shape. Input #0 'tensor<4x4xi32>' vs Callee 'tensor<8x8xi32>'}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1,2,3] {ifrt.local_view} + : (!ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]>) + -> !ifrt.array, 2x2 to [0, 1] on 2x2, [0,1,2,3]> + return +} + +func.func @callee(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> { + return %arg0 : tensor<4x4xi32> +} \ No newline at end of file From a034b3d48a9d3dbccff22800ab4b435a89f45103 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 09:10:33 -0700 Subject: [PATCH 185/410] Add python and numpy headers to the local_config_python folder in the wheel. 
PiperOrigin-RevId: 551224665 --- .../tools/pip_package/build_pip_package.sh | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 26e65e373ed87f..7cf9bc4b31352a 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -47,6 +47,22 @@ function cp_external() { cp "${src_dir}/local_config_cuda/cuda/cuda/cuda_config.h" "${dest_dir}/local_config_cuda/cuda/cuda/" } +function cp_local_config_python() { + local src_dir=$1 + local dest_dir=$2 + pushd . + cd "$src_dir" + mkdir -p "${dest_dir}/local_config_python/numpy_include/" + cp -r "pypi_numpy/site-packages/numpy/core/include/numpy" "${dest_dir}/local_config_python/numpy_include/" + mkdir -p "${dest_dir}/local_config_python/python_include/" + if is_windows; then + cp -r python_*/include/* "${dest_dir}/local_config_python/python_include/" + else + cp -r python_*/include/python*/* "${dest_dir}/local_config_python/python_include/" + fi + popd +} + function copy_xla_aot_runtime_sources() { local src_dir=$1 local dst_dir=$2 @@ -158,6 +174,9 @@ function prepare_src() { cp_external \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles \ "${EXTERNAL_INCLUDES}/" + cp_local_config_python \ + bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles \ + "${EXTERNAL_INCLUDES}/" copy_xla_aot_runtime_sources \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow \ "${XLA_AOT_RUNTIME_SOURCES}/" @@ -201,11 +220,17 @@ function prepare_src() { cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \ "${EXTERNAL_INCLUDES}" + cp_local_config_python \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \ + "${EXTERNAL_INCLUDES}" else # New-style runfiles structure (--nolegacy_external_runfiles). cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \ "${EXTERNAL_INCLUDES}" + cp_local_config_python \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \ + "${EXTERNAL_INCLUDES}" fi copy_xla_aot_runtime_sources \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow \ From daa9a3496fb02ff2713904eb67ddadae91dfcb79 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 26 Jul 2023 09:27:27 -0700 Subject: [PATCH 186/410] Integrate LLVM at llvm/llvm-project@365d6eb1f7d8 Updates LLVM usage to match [365d6eb1f7d8](https://github.com/llvm/llvm-project/commit/365d6eb1f7d8) PiperOrigin-RevId: 551229328 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 4 +- .../tests/legalize-tf-no-tf2xla-fallback.mlir | 4 +- .../mlir/tf2xla/tests/legalize-tf.mlir | 4 +- .../xla/mlir_hlo/deallocation/utils/util.cc | 15 - third_party/llvm/generated.patch | 11 + third_party/llvm/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 3122 +++++++++++++++++ third_party/triton/cl550499635.patch | 17 + third_party/triton/workspace.bzl | 1 + 9 files changed, 3159 insertions(+), 23 deletions(-) create mode 100644 third_party/triton/cl550499635.patch diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 7c5a8261d5fd12..2b11d4f77e2860 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -721,7 +721,7 @@ func.func @const() -> tensor<2xi32> { // CHECK-LABEL: func @relu( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor<1xi32>) -> tensor<1xi32> +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> // CHECK: return %[[VAL_2]] : tensor<1xi32> // CHECK: } func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { @@ -733,7 +733,7 @@ func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked( // CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { // CHECK: %[[VAL_1:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_1]], %[[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = "tf.Maximum"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor // CHECK: return %[[VAL_2]] : tensor // CHECK: } func.func @relu_unranked(%arg0: tensor) -> tensor { diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir index 92fa37f7e444a3..0f75fcd0e65a17 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-no-tf2xla-fallback.mlir @@ -3410,7 +3410,7 @@ func.func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" @@ -3439,7 +3439,7 @@ func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tenso // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor // CHECK-NEXT: 
%[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 7f523887410c73..1940344f5c04a1 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -3538,7 +3538,7 @@ func.func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" @@ -3567,7 +3567,7 @@ func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tenso // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" diff --git a/tensorflow/compiler/xla/mlir_hlo/deallocation/utils/util.cc b/tensorflow/compiler/xla/mlir_hlo/deallocation/utils/util.cc index a71f6132cbff8c..c70b568a4317fe 100644 --- a/tensorflow/compiler/xla/mlir_hlo/deallocation/utils/util.cc +++ b/tensorflow/compiler/xla/mlir_hlo/deallocation/utils/util.cc @@ -51,21 +51,6 @@ SmallVector getSuccessorRegions(RegionBranchOpInterface op, } } - // TODO(frgossen): Fix this in the `RegionBranchOpInterface`. - // RegionBranchOpInterface believes for ops are always executed at least once. - if (llvm::isa(op) && !index) { - assert(llvm::none_of(edges, - [](auto& edge) { - return edge.successorRegionIndex == std::nullopt; - }) && - "this was fixed, please remove this if"); - auto& edge = edges.emplace_back(); - edge.successorRegionIndex = edge.predecessorRegionIndex = std::nullopt; - edge.successorOpOrRegion = edge.predecessorOp = op; - edge.successorValueIndex = 0; - edge.predecessorOperandIndex = 3; - } - return edges; } diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..442a292926da98 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,12 @@ Auto generated patch. Do not edit or delete it, even if empty. 
+diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel +--- a/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel ++++ b/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel +@@ -152,6 +152,7 @@ + deps = [ + "//libc:__support_cpp_limits", + "//libc:__support_cpp_type_traits", ++ "//libc:__support_macros_properties_architectures", + "//libc:errno.__internal__", + "//libc/test/UnitTest:LibcUnitTest", + ], diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c877a9c5b35a34..c7b951aa0465f0 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "1c154bd755153b5c6ada4bbed58facf23f6abffc" - LLVM_SHA256 = "212983fc56c42105e34690097801981974acec667e9753a6465d7a5c07075e0e" + LLVM_COMMIT = "365d6eb1f7d86cf28dc7d4995c3949e9d8bead58" + LLVM_SHA256 = "12156f9a68392d8248c1239a691719cb8fd25dfcc0265acbc3b054a347a58ca8" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index ee89192bb082c9..ddd7f6c1f59a86 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -914,6 +914,3128 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/d let arguments = (ins HLO_PredOrIntTensor:$lhs, HLO_PredOrIntTensor:$rhs +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<6.914060e-01> : tensor<20x20xbf16> + %11 = stablehlo.add %8, %10 : tensor<20x20xbf16> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> +- %13 = stablehlo.add %12, %0 : tensor<20x20xbf16> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xbf16> + %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xbf16> +- %15 = stablehlo.add %14, %0 : tensor<20x20xbf16> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xbf16> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xbf16> + %17 = stablehlo.sqrt %16 : tensor<20x20xbf16> + %18 = stablehlo.add %0, %17 : tensor<20x20xbf16> +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<6.933590e-01> : tensor<20x20xf16> + %11 = stablehlo.add %8, %10 : tensor<20x20xf16> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> +- %13 = stablehlo.add %12, %0 : tensor<20x20xf16> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xf16> + %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf16> +- %15 = stablehlo.add %14, %0 : tensor<20x20xf16> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xf16> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xf16> + %17 = stablehlo.sqrt %16 : tensor<20x20xf16> + %18 = stablehlo.add %0, %17 : tensor<20x20xf16> +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir +--- 
stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<0.693147182> : tensor<20x20xf32> + %11 = stablehlo.add %8, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> +- %13 = stablehlo.add %12, %0 : tensor<20x20xf32> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xf32> + %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %15 = stablehlo.add %14, %0 : tensor<20x20xf32> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xf32> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xf32> + %17 = stablehlo.sqrt %16 : tensor<20x20xf32> + %18 = stablehlo.add %0, %17 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xbf16> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xbf16> + %7 = stablehlo.sqrt %6 : tensor<20x20xbf16> +- %8 = stablehlo.add %3, %7 : tensor<20x20xbf16> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xbf16> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xbf16> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xbf16> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xbf16> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xbf16>, tensor<20x20xbf16>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xcomplex> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xcomplex> + %7 = stablehlo.sqrt %6 : tensor<20x20xcomplex> +- %8 = stablehlo.add %3, %7 : tensor<20x20xcomplex> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xcomplex> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xcomplex> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xcomplex> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xcomplex> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xcomplex>, tensor<20x20xcomplex>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xf16> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xf16> + %7 = stablehlo.sqrt %6 : tensor<20x20xf16> +- %8 = stablehlo.add %3, %7 : tensor<20x20xf16> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xf16> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf16> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf16> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf16> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf16>, tensor<20x20xf16>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = 
stablehlo.multiply %0, %0 : tensor<20x20xf32> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xf32> + %7 = stablehlo.sqrt %6 : tensor<20x20xf32> +- %8 = stablehlo.add %3, %7 : tensor<20x20xf32> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xf32> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf32> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> + %23 = stablehlo.add %21, %22 : tensor<20x20xbf16> + %24 = stablehlo.sqrt %23 : tensor<20x20xbf16> +- %25 = stablehlo.add %18, %24 : tensor<20x20xbf16> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xbf16> + %26 = stablehlo.divide %17, %25 : tensor<20x20xbf16> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xbf16> + %28 = stablehlo.add %15, %27 : tensor<20x20xbf16> +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> + %23 = stablehlo.add %21, %22 : tensor<20x20xf16> + %24 = stablehlo.sqrt %23 : tensor<20x20xf16> +- %25 = stablehlo.add %18, %24 : tensor<20x20xf16> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xf16> + %26 = stablehlo.divide %17, %25 : tensor<20x20xf16> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xf16> + %28 = stablehlo.add %15, %27 : tensor<20x20xf16> +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %23 = stablehlo.add %21, %22 : tensor<20x20xf32> + %24 = stablehlo.sqrt %23 : tensor<20x20xf32> +- %25 = stablehlo.add %18, %24 : tensor<20x20xf32> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xf32> + %26 = stablehlo.divide %17, %25 : tensor<20x20xf32> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xf32> + %28 = stablehlo.add %15, %27 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir +@@ -33,7 +33,7 @@ + %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %9 = stablehlo.constant dense<5.000000e-01> : tensor + %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> ++ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<2.000000e+00> : tensor + %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> +@@ -133,7 +133,7 @@ + %108 = stablehlo.constant dense<0.676795303> : 
tensor<20x20xf32> + %109 = stablehlo.add %106, %108 : tensor<20x20xf32> + %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> +- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> ++ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> + %112 = stablehlo.constant dense<5.000000e-01> : tensor + %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %114 = stablehlo.constant dense<3.200000e+01> : tensor +@@ -182,7 +182,7 @@ + %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %158 = stablehlo.add %155, %157 : tensor<20x20xf32> + %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> +- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> ++ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> + %161 = stablehlo.sqrt %3 : tensor<20x20xf32> + %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> + %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir +@@ -33,7 +33,7 @@ + %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %9 = stablehlo.constant dense<5.000000e-01> : tensor + %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> ++ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<2.000000e+00> : tensor + %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> +@@ -133,7 +133,7 @@ + %108 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> + %109 = stablehlo.add %106, %108 : tensor<20x20xf32> + %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> +- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> ++ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> + %112 = stablehlo.constant dense<5.000000e-01> : tensor + %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %114 = stablehlo.constant dense<3.200000e+01> : tensor +@@ -182,7 +182,7 @@ + %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %158 = stablehlo.add %155, %157 : tensor<20x20xf32> + %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> +- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> ++ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> + %161 = stablehlo.sqrt %3 : tensor<20x20xf32> + %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> + %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir +@@ -32,7 +32,7 @@ + %7 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %8 = stablehlo.constant dense<5.000000e-01> : tensor + %9 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> ++ %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> + %11 = stablehlo.constant dense<2.000000e+00> : tensor + %12 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %13 = stablehlo.subtract %10, %12 : tensor<20x20xf32> +@@ -132,7 +132,7 @@ + %107 = stablehlo.constant 
dense<0.676795303> : tensor<20x20xf32> + %108 = stablehlo.add %105, %107 : tensor<20x20xf32> + %109 = stablehlo.subtract %108, %98 : tensor<20x20xf32> +- %110 = stablehlo.multiply %7, %109 : tensor<20x20xf32> ++ %110 = stablehlo.multiply %109, %7 : tensor<20x20xf32> + %111 = stablehlo.constant dense<5.000000e-01> : tensor + %112 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %113 = stablehlo.constant dense<3.200000e+01> : tensor +@@ -181,7 +181,7 @@ + %156 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %157 = stablehlo.add %154, %156 : tensor<20x20xf32> + %158 = stablehlo.subtract %157, %147 : tensor<20x20xf32> +- %159 = stablehlo.multiply %112, %158 : tensor<20x20xf32> ++ %159 = stablehlo.multiply %158, %112 : tensor<20x20xf32> + %160 = stablehlo.sqrt %2 : tensor<20x20xf32> + %161 = stablehlo.divide %159, %160 : tensor<20x20xf32> + %162 = stablehlo.select %5, %110, %161 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> + %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> + %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> + %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> + %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir +@@ -10,7 +10,7 @@ + %4 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %5 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> + %6 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %7 = stablehlo.multiply %3, %2 : tensor<20x20xf32> ++ %7 = stablehlo.multiply %2, %3 : tensor<20x20xf32> + %8 = stablehlo.subtract %7, %4 : tensor<20x20xf32> + %9 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir 
b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir +@@ -16,7 +16,7 @@ + %11 = stablehlo.constant dense<1> : tensor + %12 = stablehlo.subtract %3, %11 : tensor + %13 = stablehlo.select %10, %12, %3 : tensor, tensor +- %14 = stablehlo.multiply %2, %13 : tensor ++ %14 = stablehlo.multiply %13, %2 : tensor + %15 = stablehlo.subtract %1, %14 : tensor + %16 = stablehlo.constant dense<2> : tensor + %17 = stablehlo.add %15, %16 : tensor +@@ -32,7 +32,7 @@ + %27 = stablehlo.constant dense<1> : tensor + %28 = stablehlo.subtract %19, %27 : tensor + %29 = stablehlo.select %26, %28, %19 : tensor, tensor +- %30 = stablehlo.multiply %18, %29 : tensor ++ %30 = stablehlo.multiply %29, %18 : tensor + %31 = stablehlo.subtract %17, %30 : tensor + %32 = stablehlo.constant dense<-1> : tensor + %33 = stablehlo.multiply %arg0, %32 : tensor +@@ -48,7 +48,7 @@ + %43 = stablehlo.constant dense<1> : tensor + %44 = stablehlo.subtract %35, %43 : tensor + %45 = stablehlo.select %42, %44, %35 : tensor, tensor +- %46 = stablehlo.multiply %34, %45 : tensor ++ %46 = stablehlo.multiply %45, %34 : tensor + %47 = stablehlo.subtract %33, %46 : tensor + %48 = stablehlo.constant dense<-1> : tensor + %49 = stablehlo.multiply %arg0, %48 : tensor +@@ -64,7 +64,7 @@ + %59 = stablehlo.constant dense<1> : tensor + %60 = stablehlo.subtract %51, %59 : tensor + %61 = stablehlo.select %58, %60, %51 : tensor, tensor +- %62 = stablehlo.multiply %50, %61 : tensor ++ %62 = stablehlo.multiply %61, %50 : tensor + %63 = stablehlo.subtract %49, %62 : tensor + %64 = stablehlo.constant dense<2> : tensor + %65 = stablehlo.add %63, %64 : tensor +@@ -80,7 +80,7 @@ + %75 = stablehlo.constant dense<1> : tensor + %76 = stablehlo.subtract %67, %75 : tensor + %77 = stablehlo.select %74, %76, %67 : tensor, tensor +- %78 = stablehlo.multiply %66, %77 : tensor ++ %78 = stablehlo.multiply %77, %66 : tensor + %79 = stablehlo.subtract %65, %78 : tensor + %80 = stablehlo.constant dense<-1> : tensor + %81 = stablehlo.multiply %77, %80 : tensor +diff --ruN a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir +@@ -16,7 +16,7 @@ + %11 = stablehlo.constant dense<1> : tensor + %12 = stablehlo.subtract %3, %11 : tensor + %13 = stablehlo.select %10, %12, %3 : tensor, tensor +- %14 = stablehlo.multiply %2, %13 : tensor ++ %14 = stablehlo.multiply %13, %2 : tensor + %15 = stablehlo.subtract %1, %14 : tensor + %16 = stablehlo.constant dense<2> : tensor + %17 = stablehlo.add %15, %16 : tensor +@@ -32,7 +32,7 @@ + %27 = stablehlo.constant dense<1> : tensor + %28 = stablehlo.subtract %19, %27 : tensor + %29 = stablehlo.select %26, %28, %19 : tensor, tensor +- %30 = stablehlo.multiply %18, %29 : tensor ++ %30 = stablehlo.multiply %29, %18 : tensor + %31 = stablehlo.subtract %17, %30 : tensor + %32 = stablehlo.constant dense<-1> : tensor + %33 = stablehlo.multiply %arg0, %32 : tensor +@@ -48,7 +48,7 @@ + %43 = stablehlo.constant dense<1> : tensor + %44 = stablehlo.subtract %35, %43 : tensor + %45 = 
stablehlo.select %42, %44, %35 : tensor, tensor +- %46 = stablehlo.multiply %34, %45 : tensor ++ %46 = stablehlo.multiply %45, %34 : tensor + %47 = stablehlo.subtract %33, %46 : tensor + %48 = stablehlo.constant dense<-1> : tensor + %49 = stablehlo.multiply %arg0, %48 : tensor +@@ -64,7 +64,7 @@ + %59 = stablehlo.constant dense<1> : tensor + %60 = stablehlo.subtract %51, %59 : tensor + %61 = stablehlo.select %58, %60, %51 : tensor, tensor +- %62 = stablehlo.multiply %50, %61 : tensor ++ %62 = stablehlo.multiply %61, %50 : tensor + %63 = stablehlo.subtract %49, %62 : tensor + %64 = stablehlo.constant dense<2> : tensor + %65 = stablehlo.add %63, %64 : tensor +@@ -80,7 +80,7 @@ + %75 = stablehlo.constant dense<1> : tensor + %76 = stablehlo.subtract %67, %75 : tensor + %77 = stablehlo.select %74, %76, %67 : tensor, tensor +- %78 = stablehlo.multiply %66, %77 : tensor ++ %78 = stablehlo.multiply %77, %66 : tensor + %79 = stablehlo.subtract %65, %78 : tensor + %80 = stablehlo.constant dense<-1> : tensor + %81 = stablehlo.multiply %77, %80 : tensor +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir +@@ -21,7 +21,7 @@ + %15 = stablehlo.divide %11, %14 : tensor<20x20xf32> + %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> + %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> +- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> ++ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> + %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %21 = stablehlo.add %8, %20 : tensor<20x20xf32> +@@ -79,11 +79,11 @@ + %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> + %74 = stablehlo.add %66, %73 : tensor<20x20xf32> + %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> ++ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> + %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> + %79 = stablehlo.log_plus_one %78 : tensor<20x20xf32> +- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> ++ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> + %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> + %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> +@@ -95,10 +95,10 @@ + %89 = stablehlo.abs %88 : tensor<20x20xf32> + %90 = stablehlo.add %2, %89 : tensor<20x20xf32> + %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> ++ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> + %93 = stablehlo.cosine %92 : tensor<20x20xf32> + %94 = stablehlo.sine %92 : tensor<20x20xf32> +- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> ++ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> + %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> + %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> + %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir +@@ -21,7 +21,7 @@ + %15 = 
stablehlo.divide %11, %14 : tensor<20x20xf32> + %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> + %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> +- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> ++ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> + %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %21 = stablehlo.add %8, %20 : tensor<20x20xf32> +@@ -79,11 +79,11 @@ + %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> + %74 = stablehlo.add %66, %73 : tensor<20x20xf32> + %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> ++ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> + %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> + %79 = stablehlo.log_plus_one %78 : tensor<20x20xf32> +- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> ++ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> + %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> + %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> +@@ -95,10 +95,10 @@ + %89 = stablehlo.abs %88 : tensor<20x20xf32> + %90 = stablehlo.add %2, %89 : tensor<20x20xf32> + %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> ++ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> + %93 = stablehlo.cosine %92 : tensor<20x20xf32> + %94 = stablehlo.sine %92 : tensor<20x20xf32> +- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> ++ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> + %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> + %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> + %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir +@@ -20,7 +20,7 @@ + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.subtract %8, %14 : tensor<20x20xf32> + %16 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %17 = stablehlo.add %9, %16 : tensor<20x20xf32> ++ %17 = stablehlo.add %16, %9 : tensor<20x20xf32> + %18 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %19 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %20 = stablehlo.add %7, %19 : tensor<20x20xf32> +@@ -78,11 +78,11 @@ + %72 = stablehlo.divide %66, %68 : tensor<20x20xf32> + %73 = stablehlo.add %65, %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %75 = stablehlo.add %74, %7 : tensor<20x20xf32> ++ %75 = stablehlo.add %7, %74 : tensor<20x20xf32> + %76 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %77 = stablehlo.divide %7, %74 : tensor<20x20xf32> + %78 = stablehlo.log_plus_one %77 : tensor<20x20xf32> +- %79 = stablehlo.add %76, %78 : tensor<20x20xf32> ++ %79 = stablehlo.add %78, %76 : tensor<20x20xf32> + %80 = stablehlo.divide %71, %73 : tensor<20x20xf32> + %81 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %82 = stablehlo.divide %81, %75 : tensor<20x20xf32> +@@ -94,10 +94,10 @@ + %88 = stablehlo.abs %87 : tensor<20x20xf32> + %89 = stablehlo.add %0, %88 : tensor<20x20xf32> + %90 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- 
%91 = stablehlo.multiply %90, %89 : tensor<20x20xf32> ++ %91 = stablehlo.multiply %89, %90 : tensor<20x20xf32> + %92 = stablehlo.cosine %91 : tensor<20x20xf32> + %93 = stablehlo.sine %91 : tensor<20x20xf32> +- %94 = stablehlo.multiply %90, %92 : tensor<20x20xf32> ++ %94 = stablehlo.multiply %92, %90 : tensor<20x20xf32> + %95 = stablehlo.divide %94, %93 : tensor<20x20xf32> + %96 = stablehlo.subtract %84, %95 : tensor<20x20xf32> + %97 = stablehlo.select %3, %96, %84 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> + %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> + %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %6, %7 : tensor<20x20xf32> + %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> + %10 = stablehlo.add %8, %9 : tensor<20x20xf32> + %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> +@@ -33,7 +33,7 @@ + %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %28 = stablehlo.add %26, %27 : tensor<20x20xf32> + %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> ++ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> + %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %32 = stablehlo.add %30, %31 : tensor<20x20xf32> + %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> + %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> + %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %6, %7 : tensor<20x20xf32> + %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> + %10 = stablehlo.add %8, %9 : tensor<20x20xf32> + %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> +@@ -33,7 +33,7 @@ + %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %28 = stablehlo.add %26, %27 : tensor<20x20xf32> + %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> ++ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> + %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %32 = stablehlo.add %30, %31 : tensor<20x20xf32> + %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir +@@ -10,7 +10,7 @@ + %4 = stablehlo.clamp %2, %0, %3 : tensor<20x20xf32> + %5 = stablehlo.multiply %4, %4 : tensor<20x20xf32> + %6 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %7 = stablehlo.multiply %6, %5 : tensor<20x20xf32> ++ %7 = stablehlo.multiply %5, %6 : tensor<20x20xf32> + %8 = stablehlo.constant 
dense<-2.72614237E-10> : tensor<20x20xf32> + %9 = stablehlo.add %7, %8 : tensor<20x20xf32> + %10 = stablehlo.multiply %9, %5 : tensor<20x20xf32> +@@ -32,7 +32,7 @@ + %26 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %27 = stablehlo.add %25, %26 : tensor<20x20xf32> + %28 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %29 = stablehlo.multiply %28, %5 : tensor<20x20xf32> ++ %29 = stablehlo.multiply %5, %28 : tensor<20x20xf32> + %30 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %31 = stablehlo.add %29, %30 : tensor<20x20xf32> + %32 = stablehlo.multiply %31, %5 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> ++ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %16 = stablehlo.add %14, %15 : tensor<20x20xf32> + %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> +@@ -45,7 +45,7 @@ + %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %40 = stablehlo.add %38, %39 : tensor<20x20xf32> + %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> ++ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> + %43 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %44 = stablehlo.add %42, %43 : tensor<20x20xf32> + %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> +@@ -81,7 +81,7 @@ + %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> + %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> ++ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> + %79 = stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> + %80 = stablehlo.add %78, %79 : tensor<20x20xf32> + %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> ++ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %16 = stablehlo.add %14, %15 : tensor<20x20xf32> + %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> +@@ -45,7 +45,7 @@ + %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %40 = stablehlo.add %38, %39 : tensor<20x20xf32> + %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> ++ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> + %43 = 
stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %44 = stablehlo.add %42, %43 : tensor<20x20xf32> + %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> +@@ -81,7 +81,7 @@ + %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> + %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> ++ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> + %79 = stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> + %80 = stablehlo.add %78, %79 : tensor<20x20xf32> + %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir +@@ -16,7 +16,7 @@ + %10 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.compare LT, %4, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %12 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %13 = stablehlo.multiply %12, %6 : tensor<20x20xf32> ++ %13 = stablehlo.multiply %6, %12 : tensor<20x20xf32> + %14 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %15 = stablehlo.add %13, %14 : tensor<20x20xf32> + %16 = stablehlo.multiply %15, %6 : tensor<20x20xf32> +@@ -44,7 +44,7 @@ + %38 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %39 = stablehlo.add %37, %38 : tensor<20x20xf32> + %40 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %41 = stablehlo.multiply %40, %6 : tensor<20x20xf32> ++ %41 = stablehlo.multiply %6, %40 : tensor<20x20xf32> + %42 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %43 = stablehlo.add %41, %42 : tensor<20x20xf32> + %44 = stablehlo.multiply %43, %6 : tensor<20x20xf32> +@@ -80,7 +80,7 @@ + %74 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %75 = stablehlo.multiply %0, %0 : tensor<20x20xf32> + %76 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %77 = stablehlo.multiply %76, %75 : tensor<20x20xf32> ++ %77 = stablehlo.multiply %75, %76 : tensor<20x20xf32> + %78 = stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> + %79 = stablehlo.add %77, %78 : tensor<20x20xf32> + %80 = stablehlo.multiply %79, %75 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : 
(tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- 
%26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = 
stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : 
tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir 
b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> +- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> +- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> +- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir +--- 
stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 
= stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> + %241 = stablehlo.constant dense<-1.000000e+00> : tensor + %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> ++ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> + %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> + %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> + %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +--- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : 
tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> + %241 = stablehlo.constant dense<-1.000000e+00> : tensor + %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> ++ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> + %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> + %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> + %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : 
tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> + %242 = stablehlo.constant dense<-1.000000e+00> : tensor + %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> ++ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> + %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> + %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> + %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : 
tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> + %242 = stablehlo.constant dense<-1.000000e+00> : tensor + %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> ++ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> + %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> + %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> + %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -270,7 +270,7 @@ + %32 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %33 = stablehlo.subtract %32, %27 : tensor<20x20xf32> + %34 = stablehlo.select %30, %33, %27 : tensor<20x20xi1>, tensor<20x20xf32> +- %35 = stablehlo.multiply %24, %34 : tensor<20x20xf32> ++ %35 = stablehlo.multiply %34, %24 : tensor<20x20xf32> + %36 = stablehlo.sine %35 : tensor<20x20xf32> + %37 = stablehlo.log %36 : tensor<20x20xf32> + %38 = stablehlo.is_finite %37 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -288,17 +288,17 @@ + %50 = stablehlo.add %48, %49 : tensor<20x20xf32> + %51 = stablehlo.constant dense<7.500000e+00> : tensor + %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %53 = stablehlo.add %52, %48 : tensor<20x20xf32> ++ %53 = stablehlo.add %48, %52 : tensor<20x20xf32> + %54 = stablehlo.constant dense<2.01490307> : tensor + %55 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %56 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %57 = stablehlo.divide %48, %56 : tensor<20x20xf32> + %58 = stablehlo.log_plus_one %57 : tensor<20x20xf32> +- %59 = stablehlo.add %55, %58 : tensor<20x20xf32> ++ %59 = stablehlo.add %58, %55 : tensor<20x20xf32> + %60 = stablehlo.divide %53, %59 : tensor<20x20xf32> + %61 = stablehlo.subtract %50, %60 : tensor<20x20xf32> + %62 = stablehlo.multiply %61, %59 : 
tensor<20x20xf32> +- %63 = stablehlo.add %43, %62 : tensor<20x20xf32> ++ %63 = stablehlo.add %62, %43 : tensor<20x20xf32> + %64 = stablehlo.constant dense<1.000000e+00> : tensor + %65 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %66 = stablehlo.constant dense<676.520386> : tensor +@@ -309,7 +309,7 @@ + %71 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %72 = stablehlo.add %70, %71 : tensor<20x20xf32> + %73 = stablehlo.divide %67, %72 : tensor<20x20xf32> +- %74 = stablehlo.add %65, %73 : tensor<20x20xf32> ++ %74 = stablehlo.add %73, %65 : tensor<20x20xf32> + %75 = stablehlo.constant dense<-1259.13916> : tensor + %76 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %77 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -583,7 +583,7 @@ + %238 = stablehlo.multiply %iterArg_4, %237 : tensor<20x20xf32> + %239 = stablehlo.constant dense<-1.000000e+00> : tensor + %240 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %241 = stablehlo.multiply %240, %iterArg_1 : tensor<20x20xf32> ++ %241 = stablehlo.multiply %iterArg_1, %240 : tensor<20x20xf32> + %242 = stablehlo.multiply %241, %iterArg_3 : tensor<20x20xf32> + %243 = stablehlo.multiply %225, %225 : tensor<20x20xf32> + %244 = stablehlo.divide %242, %243 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add 
%41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> + %229 = stablehlo.constant dense<-1.000000e+00> : tensor + %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> ++ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> + %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> + %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> + %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +--- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = 
stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> + %229 = stablehlo.constant dense<-1.000000e+00> : tensor + %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> ++ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> + %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> + %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> + %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = 
stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> + %230 = stablehlo.constant dense<-1.000000e+00> : tensor + %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> ++ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> + %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> + %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> + %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = 
stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> + %230 = stablehlo.constant dense<-1.000000e+00> : tensor + %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> ++ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> + %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> + %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> + %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -266,7 +266,7 @@ + %28 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %29 = stablehlo.subtract %28, %23 : tensor<20x20xf32> + %30 = stablehlo.select %26, %29, %23 : tensor<20x20xi1>, tensor<20x20xf32> +- %31 = stablehlo.multiply %20, %30 : tensor<20x20xf32> ++ %31 = stablehlo.multiply %30, %20 : tensor<20x20xf32> + %32 = stablehlo.sine %31 : tensor<20x20xf32> + %33 = stablehlo.log %32 : tensor<20x20xf32> + %34 = stablehlo.is_finite %33 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -284,17 +284,17 @@ + %46 = stablehlo.add %44, %45 : tensor<20x20xf32> + %47 = stablehlo.constant dense<7.500000e+00> : tensor + %48 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %49 = stablehlo.add %48, %44 : tensor<20x20xf32> ++ %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<2.01490307> : tensor + %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %53 = stablehlo.divide %44, %52 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %51, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %51 : tensor<20x20xf32> + %56 = stablehlo.divide %49, %55 : tensor<20x20xf32> + %57 = stablehlo.subtract %46, %56 : tensor<20x20xf32> + %58 = stablehlo.multiply %57, %55 : tensor<20x20xf32> +- %59 = stablehlo.add %39, %58 : tensor<20x20xf32> ++ %59 = stablehlo.add %58, %39 : tensor<20x20xf32> + %60 = stablehlo.constant dense<1.000000e+00> : tensor + %61 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + 
%62 = stablehlo.constant dense<676.520386> : tensor +@@ -305,7 +305,7 @@ + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.add %66, %67 : tensor<20x20xf32> + %69 = stablehlo.divide %63, %68 : tensor<20x20xf32> +- %70 = stablehlo.add %61, %69 : tensor<20x20xf32> ++ %70 = stablehlo.add %69, %61 : tensor<20x20xf32> + %71 = stablehlo.constant dense<-1259.13916> : tensor + %72 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %73 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -426,7 +426,7 @@ + %226 = stablehlo.multiply %iterArg_4, %225 : tensor<20x20xf32> + %227 = stablehlo.constant dense<-1.000000e+00> : tensor + %228 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %229 = stablehlo.multiply %228, %iterArg_1 : tensor<20x20xf32> ++ %229 = stablehlo.multiply %iterArg_1, %228 : tensor<20x20xf32> + %230 = stablehlo.multiply %229, %iterArg_3 : tensor<20x20xf32> + %231 = stablehlo.multiply %213, %213 : tensor<20x20xf32> + %232 = stablehlo.divide %230, %231 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir +--- stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir ++++ stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.add %8, %11 : tensor<20x20xf32> + %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> ++ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> + %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %17 = stablehlo.add %8, %16 : tensor<20x20xf32> +@@ -54,18 +54,18 @@ + %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> + %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> ++ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %53 = 
stablehlo.divide %8, %50 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> + %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> + %57 = stablehlo.add %8, %3 : tensor<20x20xf32> + %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> + %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> + %60 = stablehlo.log %49 : tensor<20x20xf32> + %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> ++ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> + %63 = stablehlo.add %62, %60 : tensor<20x20xf32> + %64 = stablehlo.abs %2 : tensor<20x20xf32> + %65 = stablehlo.floor %64 : tensor<20x20xf32> +@@ -74,7 +74,7 @@ + %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> + %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> + %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> ++ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> + %72 = stablehlo.sine %71 : tensor<20x20xf32> + %73 = stablehlo.log %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.add %8, %11 : tensor<20x20xf32> + %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> ++ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> + %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %17 = stablehlo.add %8, %16 : tensor<20x20xf32> +@@ -54,18 +54,18 @@ + %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> + %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> ++ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %53 = stablehlo.divide %8, %50 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> + %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> + %57 = stablehlo.add %8, %3 : tensor<20x20xf32> + %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> + %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> + %60 = stablehlo.log %49 : tensor<20x20xf32> + %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> ++ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> + %63 = stablehlo.add %62, %60 : tensor<20x20xf32> + %64 = stablehlo.abs %2 : tensor<20x20xf32> + %65 = stablehlo.floor %64 : tensor<20x20xf32> +@@ -74,7 +74,7 @@ + %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> + %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> + %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> ++ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> + %72 = stablehlo.sine %71 : tensor<20x20xf32> + %73 = 
stablehlo.log %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir +@@ -16,7 +16,7 @@ + %10 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.add %7, %10 : tensor<20x20xf32> + %12 = stablehlo.divide %9, %11 : tensor<20x20xf32> +- %13 = stablehlo.add %8, %12 : tensor<20x20xf32> ++ %13 = stablehlo.add %12, %8 : tensor<20x20xf32> + %14 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %16 = stablehlo.add %7, %15 : tensor<20x20xf32> +@@ -53,18 +53,18 @@ + %47 = stablehlo.divide %44, %46 : tensor<20x20xf32> + %48 = stablehlo.add %43, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %50 = stablehlo.add %49, %7 : tensor<20x20xf32> ++ %50 = stablehlo.add %7, %49 : tensor<20x20xf32> + %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %52 = stablehlo.divide %7, %49 : tensor<20x20xf32> + %53 = stablehlo.log_plus_one %52 : tensor<20x20xf32> +- %54 = stablehlo.add %51, %53 : tensor<20x20xf32> ++ %54 = stablehlo.add %53, %51 : tensor<20x20xf32> + %55 = stablehlo.divide %50, %54 : tensor<20x20xf32> + %56 = stablehlo.add %7, %2 : tensor<20x20xf32> + %57 = stablehlo.subtract %56, %55 : tensor<20x20xf32> + %58 = stablehlo.multiply %57, %54 : tensor<20x20xf32> + %59 = stablehlo.log %48 : tensor<20x20xf32> + %60 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %61 = stablehlo.add %60, %58 : tensor<20x20xf32> ++ %61 = stablehlo.add %58, %60 : tensor<20x20xf32> + %62 = stablehlo.add %61, %59 : tensor<20x20xf32> + %63 = stablehlo.abs %0 : tensor<20x20xf32> + %64 = stablehlo.floor %63 : tensor<20x20xf32> +@@ -73,7 +73,7 @@ + %67 = stablehlo.subtract %5, %65 : tensor<20x20xf32> + %68 = stablehlo.select %66, %67, %65 : tensor<20x20xi1>, tensor<20x20xf32> + %69 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %70 = stablehlo.multiply %69, %68 : tensor<20x20xf32> ++ %70 = stablehlo.multiply %68, %69 : tensor<20x20xf32> + %71 = stablehlo.sine %70 : tensor<20x20xf32> + %72 = stablehlo.log %71 : tensor<20x20xf32> + %73 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir b/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir +--- stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir ++++ stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir +@@ -72,12 +72,12 @@ + %24 = stablehlo.subtract %14, %23 : tensor + %25 = stablehlo.minimum %18, %24 : tensor + %26 = stablehlo.constant dense<0.000000e+00> : tensor +- %27 = stablehlo.maximum %26, %25 : tensor ++ %27 = stablehlo.maximum %25, %26 : tensor + %28 = stablehlo.constant dense<1.000000e+00> : tensor + %29 = stablehlo.subtract %14, %28 : tensor + %30 = stablehlo.minimum %19, %29 : tensor + %31 = stablehlo.constant dense<0.000000e+00> : tensor +- %32 = stablehlo.maximum %31, %30 : tensor ++ %32 = stablehlo.maximum %30, %31 : tensor + %33 = stablehlo.convert %27 : (tensor) -> tensor + %34 = stablehlo.convert %32 : (tensor) -> tensor + %35 = stablehlo.constant dense<5> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir 
b/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir +@@ -338,7 +338,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -346,7 +346,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -656,13 +656,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -773,7 +773,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir +@@ -341,7 +341,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -349,7 +349,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -659,13 +659,13 @@ + %289 = stablehlo.add %287, 
%288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -776,7 +776,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir +@@ -338,7 +338,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -346,7 +346,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -656,13 +656,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -773,7 +773,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, 
tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir +@@ -341,7 +341,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -349,7 +349,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -659,13 +659,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -776,7 +776,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir +@@ -125,7 +125,7 @@ + %55 = stablehlo.add %53, %54 : tensor + %56 = stablehlo.reshape %55 : (tensor) -> tensor + %57 = stablehlo.constant dense<0.000000e+00> : tensor +- %58 = stablehlo.maximum %57, %56 : tensor ++ %58 = stablehlo.maximum %56, %57 : tensor + %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor + return %59 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir +@@ -125,7 +125,7 @@ + %55 = stablehlo.add %53, %54 : tensor + %56 = stablehlo.reshape %55 : (tensor) 
-> tensor + %57 = stablehlo.constant dense<0.000000e+00> : tensor +- %58 = stablehlo.maximum %57, %56 : tensor ++ %58 = stablehlo.maximum %56, %57 : tensor + %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor + return %59 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir +@@ -110,7 +110,7 @@ + %40 = stablehlo.add %38, %39 : tensor + %41 = stablehlo.reshape %40 : (tensor) -> tensor + %42 = stablehlo.constant dense<0.000000e+00> : tensor +- %43 = stablehlo.maximum %42, %41 : tensor ++ %43 = stablehlo.maximum %41, %42 : tensor + %44 = stablehlo.custom_call @check.eq(%43, %1) : (tensor, tensor) -> tensor + return %44 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, %43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -225,9 +225,9 @@ + %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> + %503 = stablehlo.constant dense<2.000000e+00> : tensor + %504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> + %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> +- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> + %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> + %509 = stablehlo.constant dense<1.000000e+00> : tensor + %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -237,10 +237,10 @@ + %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> + %515 = stablehlo.multiply %496, %514 : tensor<9xf32> + %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> +- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> + %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> + %519 = stablehlo.subtract %518, %510 : tensor<9xf32> +- 
%520 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> + %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> + %522 = stablehlo.multiply %519, %521 : tensor<9xf32> + %523 = stablehlo.divide %516, %522 : tensor<9xf32> +@@ -322,7 +322,7 @@ + %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %80 = stablehlo.subtract %79, %74 : tensor<9xf32> + %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> +- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> ++ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> + %83 = stablehlo.sine %82 : tensor<9xf32> + %84 = stablehlo.log %83 : tensor<9xf32> + %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> +@@ -340,17 +340,17 @@ + %97 = stablehlo.add %95, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<7.500000e+00> : tensor + %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %100 = stablehlo.add %99, %95 : tensor<9xf32> ++ %100 = stablehlo.add %95, %99 : tensor<9xf32> + %101 = stablehlo.constant dense<2.01490307> : tensor + %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %104 = stablehlo.divide %95, %103 : tensor<9xf32> + %105 = stablehlo.log_plus_one %104 : tensor<9xf32> +- %106 = stablehlo.add %102, %105 : tensor<9xf32> ++ %106 = stablehlo.add %105, %102 : tensor<9xf32> + %107 = stablehlo.divide %100, %106 : tensor<9xf32> + %108 = stablehlo.subtract %97, %107 : tensor<9xf32> + %109 = stablehlo.multiply %108, %106 : tensor<9xf32> +- %110 = stablehlo.add %90, %109 : tensor<9xf32> ++ %110 = stablehlo.add %109, %90 : tensor<9xf32> + %111 = stablehlo.constant dense<1.000000e+00> : tensor + %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %113 = stablehlo.constant dense<676.520386> : tensor +@@ -361,7 +361,7 @@ + %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %119 = stablehlo.add %117, %118 : tensor<9xf32> + %120 = stablehlo.divide %114, %119 : tensor<9xf32> +- %121 = stablehlo.add %112, %120 : tensor<9xf32> ++ %121 = stablehlo.add %120, %112 : tensor<9xf32> + %122 = stablehlo.constant dense<-1259.13916> : tensor + %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %124 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -453,7 +453,7 @@ + %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %211 = stablehlo.subtract %210, %205 : tensor<9xf32> + %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> +- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> ++ %213 = stablehlo.multiply %212, %202 : tensor<9xf32> + %214 = stablehlo.sine %213 : tensor<9xf32> + %215 = stablehlo.log %214 : tensor<9xf32> + %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> +@@ -471,17 +471,17 @@ + %228 = stablehlo.add %226, %227 : tensor<9xf32> + %229 = stablehlo.constant dense<7.500000e+00> : tensor + %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %231 = stablehlo.add %230, %226 : tensor<9xf32> ++ %231 = stablehlo.add %226, %230 : tensor<9xf32> + %232 = stablehlo.constant dense<2.01490307> : tensor + %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %235 = stablehlo.divide %226, %234 : tensor<9xf32> + %236 = stablehlo.log_plus_one %235 : tensor<9xf32> +- %237 = stablehlo.add %233, %236 : tensor<9xf32> ++ %237 = stablehlo.add %236, %233 : tensor<9xf32> + %238 = stablehlo.divide %231, %237 : tensor<9xf32> + %239 = 
stablehlo.subtract %228, %238 : tensor<9xf32> + %240 = stablehlo.multiply %239, %237 : tensor<9xf32> +- %241 = stablehlo.add %221, %240 : tensor<9xf32> ++ %241 = stablehlo.add %240, %221 : tensor<9xf32> + %242 = stablehlo.constant dense<1.000000e+00> : tensor + %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %244 = stablehlo.constant dense<676.520386> : tensor +@@ -492,7 +492,7 @@ + %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %250 = stablehlo.add %248, %249 : tensor<9xf32> + %251 = stablehlo.divide %245, %250 : tensor<9xf32> +- %252 = stablehlo.add %243, %251 : tensor<9xf32> ++ %252 = stablehlo.add %251, %243 : tensor<9xf32> + %253 = stablehlo.constant dense<-1259.13916> : tensor + %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %255 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -586,7 +586,7 @@ + %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %344 = stablehlo.subtract %343, %338 : tensor<9xf32> + %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> +- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> ++ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> + %347 = stablehlo.sine %346 : tensor<9xf32> + %348 = stablehlo.log %347 : tensor<9xf32> + %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> +@@ -604,17 +604,17 @@ + %361 = stablehlo.add %359, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<7.500000e+00> : tensor + %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %364 = stablehlo.add %363, %359 : tensor<9xf32> ++ %364 = stablehlo.add %359, %363 : tensor<9xf32> + %365 = stablehlo.constant dense<2.01490307> : tensor + %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %368 = stablehlo.divide %359, %367 : tensor<9xf32> + %369 = stablehlo.log_plus_one %368 : tensor<9xf32> +- %370 = stablehlo.add %366, %369 : tensor<9xf32> ++ %370 = stablehlo.add %369, %366 : tensor<9xf32> + %371 = stablehlo.divide %364, %370 : tensor<9xf32> + %372 = stablehlo.subtract %361, %371 : tensor<9xf32> + %373 = stablehlo.multiply %372, %370 : tensor<9xf32> +- %374 = stablehlo.add %354, %373 : tensor<9xf32> ++ %374 = stablehlo.add %373, %354 : tensor<9xf32> + %375 = stablehlo.constant dense<1.000000e+00> : tensor + %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %377 = stablehlo.constant dense<676.520386> : tensor +@@ -625,7 +625,7 @@ + %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %383 = stablehlo.add %381, %382 : tensor<9xf32> + %384 = stablehlo.divide %378, %383 : tensor<9xf32> +- %385 = stablehlo.add %376, %384 : tensor<9xf32> ++ %385 = stablehlo.add %384, %376 : tensor<9xf32> + %386 = stablehlo.constant dense<-1259.13916> : tensor + %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %388 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, 
%43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -225,9 +225,9 @@ + %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> + %503 = stablehlo.constant dense<2.000000e+00> : tensor + %504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> + %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> +- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> + %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> + %509 = stablehlo.constant dense<1.000000e+00> : tensor + %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -237,10 +237,10 @@ + %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> + %515 = stablehlo.multiply %496, %514 : tensor<9xf32> + %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> +- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> + %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> + %519 = stablehlo.subtract %518, %510 : tensor<9xf32> +- %520 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> + %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> + %522 = stablehlo.multiply %519, %521 : tensor<9xf32> + %523 = stablehlo.divide %516, %522 : tensor<9xf32> +@@ -322,7 +322,7 @@ + %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %80 = stablehlo.subtract %79, %74 : tensor<9xf32> + %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> +- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> ++ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> + %83 = stablehlo.sine %82 : tensor<9xf32> + %84 = stablehlo.log %83 : tensor<9xf32> + %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> +@@ -340,17 +340,17 @@ + %97 = stablehlo.add %95, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<7.500000e+00> : tensor + %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %100 = stablehlo.add %99, %95 : tensor<9xf32> ++ %100 = stablehlo.add %95, %99 : tensor<9xf32> + %101 = stablehlo.constant dense<2.01490307> : tensor + %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %104 = stablehlo.divide %95, %103 : tensor<9xf32> + %105 = stablehlo.log_plus_one %104 : tensor<9xf32> +- %106 = stablehlo.add %102, %105 : tensor<9xf32> ++ %106 = stablehlo.add %105, %102 : tensor<9xf32> + %107 = stablehlo.divide %100, %106 : tensor<9xf32> + %108 = stablehlo.subtract %97, %107 : tensor<9xf32> + %109 = 
stablehlo.multiply %108, %106 : tensor<9xf32> +- %110 = stablehlo.add %90, %109 : tensor<9xf32> ++ %110 = stablehlo.add %109, %90 : tensor<9xf32> + %111 = stablehlo.constant dense<1.000000e+00> : tensor + %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %113 = stablehlo.constant dense<676.520386> : tensor +@@ -361,7 +361,7 @@ + %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %119 = stablehlo.add %117, %118 : tensor<9xf32> + %120 = stablehlo.divide %114, %119 : tensor<9xf32> +- %121 = stablehlo.add %112, %120 : tensor<9xf32> ++ %121 = stablehlo.add %120, %112 : tensor<9xf32> + %122 = stablehlo.constant dense<-1259.13916> : tensor + %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %124 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -453,7 +453,7 @@ + %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %211 = stablehlo.subtract %210, %205 : tensor<9xf32> + %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> +- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> ++ %213 = stablehlo.multiply %212, %202 : tensor<9xf32> + %214 = stablehlo.sine %213 : tensor<9xf32> + %215 = stablehlo.log %214 : tensor<9xf32> + %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> +@@ -471,17 +471,17 @@ + %228 = stablehlo.add %226, %227 : tensor<9xf32> + %229 = stablehlo.constant dense<7.500000e+00> : tensor + %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %231 = stablehlo.add %230, %226 : tensor<9xf32> ++ %231 = stablehlo.add %226, %230 : tensor<9xf32> + %232 = stablehlo.constant dense<2.01490307> : tensor + %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %235 = stablehlo.divide %226, %234 : tensor<9xf32> + %236 = stablehlo.log_plus_one %235 : tensor<9xf32> +- %237 = stablehlo.add %233, %236 : tensor<9xf32> ++ %237 = stablehlo.add %236, %233 : tensor<9xf32> + %238 = stablehlo.divide %231, %237 : tensor<9xf32> + %239 = stablehlo.subtract %228, %238 : tensor<9xf32> + %240 = stablehlo.multiply %239, %237 : tensor<9xf32> +- %241 = stablehlo.add %221, %240 : tensor<9xf32> ++ %241 = stablehlo.add %240, %221 : tensor<9xf32> + %242 = stablehlo.constant dense<1.000000e+00> : tensor + %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %244 = stablehlo.constant dense<676.520386> : tensor +@@ -492,7 +492,7 @@ + %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %250 = stablehlo.add %248, %249 : tensor<9xf32> + %251 = stablehlo.divide %245, %250 : tensor<9xf32> +- %252 = stablehlo.add %243, %251 : tensor<9xf32> ++ %252 = stablehlo.add %251, %243 : tensor<9xf32> + %253 = stablehlo.constant dense<-1259.13916> : tensor + %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %255 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -586,7 +586,7 @@ + %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %344 = stablehlo.subtract %343, %338 : tensor<9xf32> + %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> +- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> ++ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> + %347 = stablehlo.sine %346 : tensor<9xf32> + %348 = stablehlo.log %347 : tensor<9xf32> + %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> +@@ -604,17 +604,17 @@ + %361 = stablehlo.add %359, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<7.500000e+00> : tensor + %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %364 = 
stablehlo.add %363, %359 : tensor<9xf32> ++ %364 = stablehlo.add %359, %363 : tensor<9xf32> + %365 = stablehlo.constant dense<2.01490307> : tensor + %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %368 = stablehlo.divide %359, %367 : tensor<9xf32> + %369 = stablehlo.log_plus_one %368 : tensor<9xf32> +- %370 = stablehlo.add %366, %369 : tensor<9xf32> ++ %370 = stablehlo.add %369, %366 : tensor<9xf32> + %371 = stablehlo.divide %364, %370 : tensor<9xf32> + %372 = stablehlo.subtract %361, %371 : tensor<9xf32> + %373 = stablehlo.multiply %372, %370 : tensor<9xf32> +- %374 = stablehlo.add %354, %373 : tensor<9xf32> ++ %374 = stablehlo.add %373, %354 : tensor<9xf32> + %375 = stablehlo.constant dense<1.000000e+00> : tensor + %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %377 = stablehlo.constant dense<676.520386> : tensor +@@ -625,7 +625,7 @@ + %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %383 = stablehlo.add %381, %382 : tensor<9xf32> + %384 = stablehlo.divide %378, %383 : tensor<9xf32> +- %385 = stablehlo.add %376, %384 : tensor<9xf32> ++ %385 = stablehlo.add %384, %376 : tensor<9xf32> + %386 = stablehlo.constant dense<-1259.13916> : tensor + %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %388 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, %43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -222,9 +222,9 @@ + %498 = stablehlo.multiply %497, %iterArg_6 : tensor<9xf32> + %499 = stablehlo.constant dense<2.000000e+00> : tensor + %500 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %501 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %501 = stablehlo.multiply %492, %500 : tensor<9xf32> + %502 = stablehlo.add %iterArg_4, %501 : tensor<9xf32> +- %503 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %503 = stablehlo.multiply %492, %500 : tensor<9xf32> + %504 = stablehlo.add %iterArg_4, %503 : tensor<9xf32> + %505 = stablehlo.constant dense<1.000000e+00> : tensor + %506 = 
stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -234,10 +234,10 @@ + %510 = stablehlo.subtract %iterArg_5, %492 : tensor<9xf32> + %511 = stablehlo.multiply %492, %510 : tensor<9xf32> + %512 = stablehlo.multiply %511, %iterArg_6 : tensor<9xf32> +- %513 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %513 = stablehlo.multiply %492, %500 : tensor<9xf32> + %514 = stablehlo.add %iterArg_4, %513 : tensor<9xf32> + %515 = stablehlo.subtract %514, %506 : tensor<9xf32> +- %516 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %516 = stablehlo.multiply %492, %500 : tensor<9xf32> + %517 = stablehlo.add %iterArg_4, %516 : tensor<9xf32> + %518 = stablehlo.multiply %515, %517 : tensor<9xf32> + %519 = stablehlo.divide %512, %518 : tensor<9xf32> +@@ -319,7 +319,7 @@ + %76 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %77 = stablehlo.subtract %76, %71 : tensor<9xf32> + %78 = stablehlo.select %74, %77, %71 : tensor<9xi1>, tensor<9xf32> +- %79 = stablehlo.multiply %68, %78 : tensor<9xf32> ++ %79 = stablehlo.multiply %78, %68 : tensor<9xf32> + %80 = stablehlo.sine %79 : tensor<9xf32> + %81 = stablehlo.log %80 : tensor<9xf32> + %82 = stablehlo.is_finite %81 : (tensor<9xf32>) -> tensor<9xi1> +@@ -337,17 +337,17 @@ + %94 = stablehlo.add %92, %93 : tensor<9xf32> + %95 = stablehlo.constant dense<7.500000e+00> : tensor + %96 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %97 = stablehlo.add %96, %92 : tensor<9xf32> ++ %97 = stablehlo.add %92, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<2.01490307> : tensor + %99 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %100 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %101 = stablehlo.divide %92, %100 : tensor<9xf32> + %102 = stablehlo.log_plus_one %101 : tensor<9xf32> +- %103 = stablehlo.add %99, %102 : tensor<9xf32> ++ %103 = stablehlo.add %102, %99 : tensor<9xf32> + %104 = stablehlo.divide %97, %103 : tensor<9xf32> + %105 = stablehlo.subtract %94, %104 : tensor<9xf32> + %106 = stablehlo.multiply %105, %103 : tensor<9xf32> +- %107 = stablehlo.add %87, %106 : tensor<9xf32> ++ %107 = stablehlo.add %106, %87 : tensor<9xf32> + %108 = stablehlo.constant dense<1.000000e+00> : tensor + %109 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %110 = stablehlo.constant dense<676.520386> : tensor +@@ -358,7 +358,7 @@ + %115 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %116 = stablehlo.add %114, %115 : tensor<9xf32> + %117 = stablehlo.divide %111, %116 : tensor<9xf32> +- %118 = stablehlo.add %109, %117 : tensor<9xf32> ++ %118 = stablehlo.add %117, %109 : tensor<9xf32> + %119 = stablehlo.constant dense<-1259.13916> : tensor + %120 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %121 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -450,7 +450,7 @@ + %207 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %208 = stablehlo.subtract %207, %202 : tensor<9xf32> + %209 = stablehlo.select %205, %208, %202 : tensor<9xi1>, tensor<9xf32> +- %210 = stablehlo.multiply %199, %209 : tensor<9xf32> ++ %210 = stablehlo.multiply %209, %199 : tensor<9xf32> + %211 = stablehlo.sine %210 : tensor<9xf32> + %212 = stablehlo.log %211 : tensor<9xf32> + %213 = stablehlo.is_finite %212 : (tensor<9xf32>) -> tensor<9xi1> +@@ -468,17 +468,17 @@ + %225 = stablehlo.add %223, %224 : tensor<9xf32> + %226 = stablehlo.constant dense<7.500000e+00> : tensor + %227 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %228 = stablehlo.add %227, %223 : tensor<9xf32> ++ %228 = stablehlo.add %223, %227 : 
tensor<9xf32> + %229 = stablehlo.constant dense<2.01490307> : tensor + %230 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %231 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %232 = stablehlo.divide %223, %231 : tensor<9xf32> + %233 = stablehlo.log_plus_one %232 : tensor<9xf32> +- %234 = stablehlo.add %230, %233 : tensor<9xf32> ++ %234 = stablehlo.add %233, %230 : tensor<9xf32> + %235 = stablehlo.divide %228, %234 : tensor<9xf32> + %236 = stablehlo.subtract %225, %235 : tensor<9xf32> + %237 = stablehlo.multiply %236, %234 : tensor<9xf32> +- %238 = stablehlo.add %218, %237 : tensor<9xf32> ++ %238 = stablehlo.add %237, %218 : tensor<9xf32> + %239 = stablehlo.constant dense<1.000000e+00> : tensor + %240 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %241 = stablehlo.constant dense<676.520386> : tensor +@@ -489,7 +489,7 @@ + %246 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %247 = stablehlo.add %245, %246 : tensor<9xf32> + %248 = stablehlo.divide %242, %247 : tensor<9xf32> +- %249 = stablehlo.add %240, %248 : tensor<9xf32> ++ %249 = stablehlo.add %248, %240 : tensor<9xf32> + %250 = stablehlo.constant dense<-1259.13916> : tensor + %251 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %252 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -583,7 +583,7 @@ + %340 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %341 = stablehlo.subtract %340, %335 : tensor<9xf32> + %342 = stablehlo.select %338, %341, %335 : tensor<9xi1>, tensor<9xf32> +- %343 = stablehlo.multiply %332, %342 : tensor<9xf32> ++ %343 = stablehlo.multiply %342, %332 : tensor<9xf32> + %344 = stablehlo.sine %343 : tensor<9xf32> + %345 = stablehlo.log %344 : tensor<9xf32> + %346 = stablehlo.is_finite %345 : (tensor<9xf32>) -> tensor<9xi1> +@@ -601,17 +601,17 @@ + %358 = stablehlo.add %356, %357 : tensor<9xf32> + %359 = stablehlo.constant dense<7.500000e+00> : tensor + %360 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %361 = stablehlo.add %360, %356 : tensor<9xf32> ++ %361 = stablehlo.add %356, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<2.01490307> : tensor + %363 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %364 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %365 = stablehlo.divide %356, %364 : tensor<9xf32> + %366 = stablehlo.log_plus_one %365 : tensor<9xf32> +- %367 = stablehlo.add %363, %366 : tensor<9xf32> ++ %367 = stablehlo.add %366, %363 : tensor<9xf32> + %368 = stablehlo.divide %361, %367 : tensor<9xf32> + %369 = stablehlo.subtract %358, %368 : tensor<9xf32> + %370 = stablehlo.multiply %369, %367 : tensor<9xf32> +- %371 = stablehlo.add %351, %370 : tensor<9xf32> ++ %371 = stablehlo.add %370, %351 : tensor<9xf32> + %372 = stablehlo.constant dense<1.000000e+00> : tensor + %373 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %374 = stablehlo.constant dense<676.520386> : tensor +@@ -622,7 +622,7 @@ + %379 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %380 = stablehlo.add %378, %379 : tensor<9xf32> + %381 = stablehlo.divide %375, %380 : tensor<9xf32> +- %382 = stablehlo.add %373, %381 : tensor<9xf32> ++ %382 = stablehlo.add %381, %373 : tensor<9xf32> + %383 = stablehlo.constant dense<-1259.13916> : tensor + %384 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %385 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir +--- 
stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir +@@ -19,7 +19,7 @@ + %13 = stablehlo.add %10, %11 : tensor<20x20xf32> + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.add %10, %14 : tensor<20x20xf32> +- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> ++ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> + %17 = stablehlo.abs %2 : tensor<20x20xf32> + %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir +@@ -19,7 +19,7 @@ + %13 = stablehlo.add %10, %11 : tensor<20x20xf32> + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.add %10, %14 : tensor<20x20xf32> +- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> ++ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> + %17 = stablehlo.abs %2 : tensor<20x20xf32> + %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir +@@ -18,7 +18,7 @@ + %12 = stablehlo.add %9, %10 : tensor<20x20xf32> + %13 = stablehlo.divide %9, %12 : tensor<20x20xf32> + %14 = stablehlo.add %9, %13 : tensor<20x20xf32> +- %15 = stablehlo.multiply %11, %14 : tensor<20x20xf32> ++ %15 = stablehlo.multiply %14, %11 : tensor<20x20xf32> + %16 = stablehlo.abs %0 : tensor<20x20xf32> + %17 = stablehlo.compare LT, %16, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %18 = stablehlo.select %17, %15, %8 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.constant dense<0> : tensor<1xi32> + %3 = stablehlo.constant dense<0> : tensor<1xi32> + %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<1x4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : 
(tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir +@@ -29,7 +29,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi64> +- %23 = stablehlo.add %8, %22 : tensor<1xi64> ++ %23 = stablehlo.add %22, %8 : tensor<1xi64> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi64> + %25 = stablehlo.convert %24 : (tensor<1xi64>) -> tensor<1xi32> + %26 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +@@ -46,7 +46,7 @@ + %37 = stablehlo.compare LT, %9, %36, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %38 = stablehlo.constant dense<3> : tensor + %39 = stablehlo.broadcast_in_dim %38, dims = [] : (tensor) -> tensor<1xi64> +- %40 = stablehlo.add %9, %39 : tensor<1xi64> ++ %40 = stablehlo.add %39, %9 : tensor<1xi64> + %41 = stablehlo.select %37, %40, %9 : tensor<1xi1>, tensor<1xi64> + %42 = stablehlo.convert %41 : (tensor<1xi64>) -> tensor<1xi32> + %43 = stablehlo.broadcast_in_dim %42, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir +--- stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir +@@ -34,7 +34,7 @@ + %25 = stablehlo.compare LT, %16, %24, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %26 = stablehlo.constant dense<2> : tensor + %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<1xi64> +- %28 = stablehlo.add %16, %27 : tensor<1xi64> ++ %28 = stablehlo.add %27, %16 : tensor<1xi64> + %29 = stablehlo.select %25, %28, %16 : tensor<1xi1>, tensor<1xi64> + %30 = stablehlo.convert %29 : (tensor<1xi64>) -> tensor<1xi32> + %31 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +@@ -49,7 +49,7 @@ + %40 = stablehlo.compare LT, %17, %39, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %41 = stablehlo.constant dense<2> : tensor + %42 = stablehlo.broadcast_in_dim %41, dims = [] : (tensor) -> tensor<1xi64> +- %43 = stablehlo.add %17, %42 : tensor<1xi64> ++ %43 = stablehlo.add %42, %17 : tensor<1xi64> + %44 = stablehlo.select %40, %43, %17 : tensor<1xi1>, tensor<1xi64> + %45 = stablehlo.convert %44 : (tensor<1xi64>) -> tensor<1xi32> + %46 = stablehlo.broadcast_in_dim %45, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir +--- stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir +@@ -47,7 +47,7 @@ + %38 = stablehlo.compare LT, %29, %37, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %39 = stablehlo.constant dense<2> : tensor + %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<2xi64> +- %41 = stablehlo.add %29, %40 : tensor<2xi64> ++ %41 = stablehlo.add %40, %29 : tensor<2xi64> + %42 = stablehlo.select %38, %41, %29 : tensor<2xi1>, tensor<2xi64> + %43 = stablehlo.convert %42 : 
(tensor<2xi64>) -> tensor<2xi32> + %44 = stablehlo.broadcast_in_dim %43, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -62,7 +62,7 @@ + %53 = stablehlo.compare LT, %30, %52, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %54 = stablehlo.constant dense<2> : tensor + %55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor) -> tensor<2xi64> +- %56 = stablehlo.add %30, %55 : tensor<2xi64> ++ %56 = stablehlo.add %55, %30 : tensor<2xi64> + %57 = stablehlo.select %53, %56, %30 : tensor<2xi1>, tensor<2xi64> + %58 = stablehlo.convert %57 : (tensor<2xi64>) -> tensor<2xi32> + %59 = stablehlo.broadcast_in_dim %58, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir +@@ -41,7 +41,7 @@ + %32 = stablehlo.compare LT, %22, %31, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %33 = stablehlo.constant dense<2> : tensor + %34 = stablehlo.broadcast_in_dim %33, dims = [] : (tensor) -> tensor<2xi64> +- %35 = stablehlo.add %22, %34 : tensor<2xi64> ++ %35 = stablehlo.add %34, %22 : tensor<2xi64> + %36 = stablehlo.select %32, %35, %22 : tensor<2xi1>, tensor<2xi64> + %37 = stablehlo.convert %36 : (tensor<2xi64>) -> tensor<2xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -56,7 +56,7 @@ + %47 = stablehlo.compare LT, %23, %46, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %48 = stablehlo.constant dense<2> : tensor + %49 = stablehlo.broadcast_in_dim %48, dims = [] : (tensor) -> tensor<2xi64> +- %50 = stablehlo.add %23, %49 : tensor<2xi64> ++ %50 = stablehlo.add %49, %23 : tensor<2xi64> + %51 = stablehlo.select %47, %50, %23 : tensor<2xi1>, tensor<2xi64> + %52 = stablehlo.convert %51 : (tensor<2xi64>) -> tensor<2xi32> + %53 = stablehlo.broadcast_in_dim %52, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir ++++ stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir +@@ -45,7 +45,7 @@ + %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %37 = stablehlo.constant dense<4> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> +- %39 = stablehlo.add %22, %38 : tensor<2xi64> ++ %39 = stablehlo.add %38, %22 : tensor<2xi64> + %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> + %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> + %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -64,7 +64,7 @@ + %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %56 = stablehlo.constant dense<4> : tensor + %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> +- %58 = stablehlo.add %23, %57 : tensor<2xi64> ++ %58 
= stablehlo.add %57, %23 : tensor<2xi64> + %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> + %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> + %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir ++++ stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir +@@ -45,7 +45,7 @@ + %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %37 = stablehlo.constant dense<4> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> +- %39 = stablehlo.add %22, %38 : tensor<2xi64> ++ %39 = stablehlo.add %38, %22 : tensor<2xi64> + %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> + %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> + %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -64,7 +64,7 @@ + %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %56 = stablehlo.constant dense<4> : tensor + %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> +- %58 = stablehlo.add %23, %57 : tensor<2xi64> ++ %58 = stablehlo.add %57, %23 : tensor<2xi64> + %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> + %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> + %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir diff --git a/third_party/triton/cl550499635.patch b/third_party/triton/cl550499635.patch new file mode 100644 index 00000000000000..ddceb0765a5f49 --- /dev/null +++ b/third_party/triton/cl550499635.patch @@ -0,0 +1,17 @@ +==== triton/test/Analysis/test-alias.mlir#5 - /google/src/cloud/shyshkov/mlir_8dbddb17180fff0ed881d75689651769b9a9b483_1690193084/triton/test/Analysis/test-alias.mlir ==== +# action=edit type=text +--- triton/test/Analysis/test-alias.mlir 2023-05-26 16:07:06.000000000 -0700 ++++ triton/test/Analysis/test-alias.mlir 2023-07-24 03:52:39.000000000 -0700 +@@ -192,10 +192,10 @@ + // CHECK-NEXT: %arg9 -> %cst_1 + // CHECK-NEXT: %0#0 -> %cst + // CHECK-NEXT: %0#1 -> %cst_0 +- // CHECK-NEXT: %0#2 -> %cst_2,%cst_2 ++ // CHECK-NEXT: %0#2 -> %cst_1,%cst_2,%cst_2 + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) { + // CHECK-NEXT: %arg11 -> %cst_1,%cst_2,%cst_2 +- // CHECK-NEXT: %1 -> %cst_2,%cst_2 ++ // CHECK-NEXT: %1 -> %cst_1,%cst_2,%cst_2 + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (tensor<128x32xf16, #A_SHARED>) { + // CHECK-NEXT: %2 -> %cst_2,%cst_2 + %c_shared_next_next = scf.if %i1 -> tensor<128x32xf16, #A_SHARED> { diff 
--git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 219c10b85e2e7f..c3b66c14ffe155 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -16,5 +16,6 @@ def repo(): # For temporary changes which haven't landed upstream yet. patch_file = [ "//third_party/triton:cl536931041.patch", + "//third_party/triton:cl550499635.patch", ], ) From f178576d12e9dcc31daabdd2743beba7dd77af7c Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Wed, 26 Jul 2023 09:49:28 -0700 Subject: [PATCH 187/410] Update sample_stable_delegate for promotion of experimental/acceleration/configuration out of experimental. PiperOrigin-RevId: 551235514 --- .../sample_stable_delegate/sample_app_using_stable_delegate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc index aaaebb4dd14cda..2ef7d180bbe82d 100644 --- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc +++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_app_using_stable_delegate.cc @@ -19,9 +19,9 @@ limitations under the License. #include +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/c/c_api.h" // For TfLiteTensorByteSize. #include "tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h" -#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model_builder.h" From 0ef963cc44ddddf3b5a6f123c6365a57eba2e26d Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 26 Jul 2023 09:59:09 -0700 Subject: [PATCH 188/410] Increase the memory limit for the dtensor GPU test. PiperOrigin-RevId: 551238120 --- tensorflow/dtensor/python/tests/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/dtensor/python/tests/test_util.py b/tensorflow/dtensor/python/tests/test_util.py index fed11025d42a97..de21130c247682 100644 --- a/tensorflow/dtensor/python/tests/test_util.py +++ b/tensorflow/dtensor/python/tests/test_util.py @@ -51,7 +51,7 @@ DEFAULT_TOL = 1e-5 -_DEFAULT_GPU_MEMORY_LIMIT = 200 # MB +_DEFAULT_GPU_MEMORY_LIMIT = 1024 # 1G def get_use_xla_spmd(device_type): From 0321ee174ab470ca8480d890a5ed5204d050b044 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 26 Jul 2023 10:24:33 -0700 Subject: [PATCH 189/410] Remove trigraph --- tensorflow/lite/core/interpreter_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/core/interpreter_builder.cc b/tensorflow/lite/core/interpreter_builder.cc index d1cf3eec2db7c1..18f1ce424b25d5 100644 --- a/tensorflow/lite/core/interpreter_builder.cc +++ b/tensorflow/lite/core/interpreter_builder.cc @@ -49,7 +49,7 @@ limitations under the License. #include "tensorflow/lite/version.h" // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11. -//(introduced in stdc11 but realized in C++17) +// (introduced in stdc11 but realized in C++17) #if __cplusplus >= 201703L && __STDC_VERSION__ >= 201112L #if !defined(__ANDROID__) || __ANDROID_API__ >= 28 // Neither Apple nor Windows provide aligned_alloc. 
From cc5aa34f67a2c7a0ffe5dd186d35f9b4aedcb5cf Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 26 Jul 2023 10:24:39 -0700 Subject: [PATCH 190/410] Remove trigraph --- tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 67da19fc95a773..d4ef695a72a7eb 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -36,7 +36,7 @@ limitations under the License. #ifdef USE_NEON // aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11. -//(introduced in stdc11 but realized in C++17) +// (introduced in stdc11 but realized in C++17) #if __cplusplus >= 201703L && __STDC_VERSION__ >= 201112L #if !defined(__ANDROID__) || __ANDROID_API__ >= 28 // Neither Apple nor Windows provide aligned_alloc. From 7ba36c17a36cdce72ff294f9424b78980451eb6d Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Wed, 26 Jul 2023 10:30:02 -0700 Subject: [PATCH 191/410] Remove unnecessary 'const' from pass-by-value function parameters. Also fix typo in SetAllowBufferHandleOutput comment: false->true. Also fix #include order to match style guide. PiperOrigin-RevId: 551247708 --- tensorflow/lite/core/interpreter.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index 92b6f2404cf515..2daaa186f11cca 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -41,8 +41,8 @@ limitations under the License. #include "tensorflow/lite/allocation.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/async/async_signature_runner.h" #include "tensorflow/lite/core/api/profiler.h" +#include "tensorflow/lite/core/async/async_signature_runner.h" #include "tensorflow/lite/core/c/common.h" // IWYU pragma: export #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/remat/metadata_util.h" @@ -190,7 +190,7 @@ class Interpreter { } TfLiteStatus SetTensorParametersReadOnly( - int tensor_index, TfLiteType type, const char* name, const size_t rank, + int tensor_index, TfLiteType type, const char* name, size_t rank, const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr); @@ -221,9 +221,9 @@ class Interpreter { is_variable, rank_dims_signature, dims_signature_pointer); } TfLiteStatus SetTensorParametersReadWrite( - int tensor_index, TfLiteType type, const char* name, const size_t rank, + int tensor_index, TfLiteType type, const char* name, size_t rank, const int* dims, TfLiteQuantizationParams quantization, - bool is_variable = false, const size_t rank_dims_signature = 0, + bool is_variable = false, size_t rank_dims_signature = 0, const int* dims_signature = nullptr); /// Enables application to cancel in flight invocation with `Cancel`. @@ -702,7 +702,7 @@ class Interpreter { /// When using hardware delegation, Interpreter will make the data of output /// tensors available in `tensor->data` by default. If the application can /// consume the buffer handle directly (e.g. 
reading output from OpenGL - /// texture), it can set this flag to false, so Interpreter won't copy the + /// texture), it can set this flag to true, so Interpreter won't copy the /// data from buffer handle to CPU memory. void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) { allow_buffer_handle_output_ = allow_buffer_handle_output; From d9ec8c5ecec9150c7eef2437eed19f35a503d355 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Wed, 26 Jul 2023 11:13:30 -0700 Subject: [PATCH 192/410] Internal visibility change only. PiperOrigin-RevId: 551261650 --- tensorflow/python/framework/BUILD | 13 ++----------- tensorflow/tensorflow.bzl | 3 +++ tensorflow/tensorflow.default.bzl | 2 ++ 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index d579a313cc886a..b281baf9be9400 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -12,22 +12,13 @@ load( "tf_gen_op_wrapper_py", "tf_kernel_library", ) -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test", "tf_python_framework_friends", "tf_python_pybind_extension") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_protos_grappler") # @unused load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_xla_deps_py") -visibility = [ - "//tensorflow:__subpackages__", - "//tensorflow/dtensor:dtensor-internal", - "//learning/brain/tfrt:__subpackages__", - "//third_party/py/tensorflow_numerics:__subpackages__", - "//tensorflow_models/google:__subpackages__", - "//learning/brain/google/data:__subpackages__", - "//learning/brain/experimental/tfq:__subpackages__", - "//learning/brain/google/python/ops:__subpackages__", -] +visibility = tf_python_framework_friends() package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index bb1d1394785c1d..9fa8c9d7628b2a 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -3531,3 +3531,6 @@ def replace_with_portable_tf_lib_when_required(non_portable_tf_deps, use_lib_wit "//tensorflow:ios": [portable_tf_lib], "//conditions:default": non_portable_tf_deps, }) + +def tf_python_framework_friends(): + return ["//tensorflow:__subpackages__"] diff --git a/tensorflow/tensorflow.default.bzl b/tensorflow/tensorflow.default.bzl index 9c6515f9798e5e..b2aba538e9a9ec 100644 --- a/tensorflow/tensorflow.default.bzl +++ b/tensorflow/tensorflow.default.bzl @@ -38,6 +38,7 @@ load( _tf_py_build_info_genrule = "tf_py_build_info_genrule", _tf_py_test = "tf_py_test", _tf_pybind_cc_library_wrapper = "tf_pybind_cc_library_wrapper", + _tf_python_framework_friends = "tf_python_framework_friends", _tf_python_pybind_extension = "tf_python_pybind_extension", _tf_selective_registration_deps = "tf_selective_registration_deps", _tf_version_info_genrule = "tf_version_info_genrule", @@ -91,3 +92,4 @@ genrule = _genrule internal_tfrt_deps = _internal_tfrt_deps tf_disable_ptxas_warning_flags = _tf_disable_ptxas_warning_flags replace_with_portable_tf_lib_when_required = _replace_with_portable_tf_lib_when_required 
+tf_python_framework_friends = _tf_python_framework_friends From cf4afb6b6ff75ed69bd02cc9f7cdbe47bfa2bfa8 Mon Sep 17 00:00:00 2001 From: Jie Sun Date: Wed, 26 Jul 2023 12:01:06 -0700 Subject: [PATCH 193/410] special allocations' aggregated metrics need to consider memory color. PiperOrigin-RevId: 551275563 --- .../convert/hlo_proto_to_memory_visualization_utils.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc index c562642ecad327..44aea4dd0da201 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc @@ -466,13 +466,14 @@ void Convert(const BufferAllocationProto& proto, } void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper, - int64_t small_buffer_size, + int64_t memory_color, int64_t small_buffer_size, PreprocessResult* result) { int64_t entry_parameters_bytes = 0; int64_t non_reusable_bytes = 0; int64_t maybe_live_out_bytes = 0; - for (const BufferAllocationProto& buffer_allocation : - wrapper.GetHloProto().buffer_assignment().buffer_allocations()) { + for (const auto* buffer_allocation_struct : + wrapper.GetBufferAllocations(memory_color)) { + const auto& buffer_allocation = buffer_allocation_struct->proto(); if (buffer_allocation.is_entry_computation_parameter()) { entry_parameters_bytes += buffer_allocation.size(); } @@ -1020,7 +1021,8 @@ void GeneratePreprocessResult(const HloProtoBufferWrapper& wrapper, logical_buffer->span->second); } - NoteSpecialAllocations(wrapper, peak_snapshot.small_buffer_size, result); + NoteSpecialAllocations(wrapper, memory_color, peak_snapshot.small_buffer_size, + result); ConvertAllocationTimeline(wrapper, simulator_stats, memory_color, result); } From 7f8be6eb6c4220e63a2e73787c78ae0df4fbbe50 Mon Sep 17 00:00:00 2001 From: Jie Sun Date: Wed, 26 Jul 2023 12:03:45 -0700 Subject: [PATCH 194/410] deprecate instruction name, it is changed over 1 years ago. PiperOrigin-RevId: 551276374 --- .../convert/hlo_proto_to_memory_visualization_utils.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc index 44aea4dd0da201..16e57ea9b2d103 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc @@ -295,13 +295,8 @@ class HloProtoBufferWrapper { for (const auto& assigned : buffer_allocation.assigned()) { const auto id = assigned.logical_buffer_id(); const auto* logical_buffer = id_to_logical_buffer_proto.at(id); - const auto& instruction_name = - logical_buffer->defined_at().instruction_name(); const auto* instruction = - instruction_name.empty() - ? 
unique_id_to_hlo.at( - logical_buffer->defined_at().instruction_id()) - : name_to_hlo.at(instruction_name); + unique_id_to_hlo.at(logical_buffer->defined_at().instruction_id()); id_to_logical_buffer_[id] = std::make_unique( *logical_buffer, *buffer_allocation_s, *instruction, assigned.offset()); From 0cc2c308bf94cb2bc22b6be607d68bacf4046c53 Mon Sep 17 00:00:00 2001 From: David Silverstone Date: Wed, 26 Jul 2023 12:11:35 -0700 Subject: [PATCH 195/410] Add macros for working with TF_Status in C++ code `TF_STATUS_ASSIGN_OR_RETURN` and `TF_STATUS_RETURN_IF_ERROR` PiperOrigin-RevId: 551278625 --- tensorflow/c/tf_status_helper.h | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 829eea54512d75..7e54a5c4e76b98 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_ #define TENSORFLOW_C_TF_STATUS_HELPER_H_ +#include + #include "tensorflow/c/tf_status.h" #include "tensorflow/tsl/platform/status.h" @@ -41,4 +43,30 @@ using TF_StatusPtr = std::unique_ptr; } // namespace tensorflow +#define TF_STATUS_ASSIGN_OR_RETURN(lhs, rexpr, c_status) \ + _TF_STATUS_ASSIGN_OR_RETURN_IMPL( \ + _TF_STATUS_CONCAT(_status_or_value, __COUNTER__), lhs, rexpr, c_status); + +#define _TF_STATUS_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) { \ + tensorflow::Set_TF_Status_from_Status(c_status, statusor.status()); \ + return; \ + } \ + lhs = std::move(*statusor) + +#define TF_STATUS_RETURN_IF_ERROR(rexpr, c_status) \ + _TF_STATUS_RETURN_IF_ERROR_IMPL(_TF_STATUS_CONCAT(_status, __COUNTER__), \ + rexpr, c_status); + +#define _TF_STATUS_RETURN_IF_ERROR_IMPL(status, rexpr, c_status) \ + auto status = (rexpr); \ + if (!status.ok()) { \ + tensorflow::Set_TF_Status_from_Status(c_status, status); \ + return; \ + } + +#define _TF_STATUS_CONCAT(x, y) _TF_STATUS_CONCAT_IMPL(x, y) +#define _TF_STATUS_CONCAT_IMPL(x, y) x##y + #endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ From 3874ea238780c117ec917d191099c479949803cc Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 26 Jul 2023 12:56:15 -0700 Subject: [PATCH 196/410] Correct the device assignment for tf._XlaCompile PiperOrigin-RevId: 551290442 --- .../compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir | 2 +- .../mlir/tensorflow/transforms/xla_rewrite_v2.cc | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir index 010546c2946535..f943443e3a4da9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite_v2.mlir @@ -78,7 +78,7 @@ module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0"] func.func @outside_compilation_in_generic_pipeline(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: tf_device.launch // CHECK: "tf._XlaCompile"() {function = @func, must_compile = true, operand_segment_sizes = array} - // CHECK: {device = "/job:localhost/replica:0/task:0/device:CPU:0"} + // CHECK: {device = "/job:localhost/replica:0/task:0/device:GPU:0"} // CHECK: tf_device.parallel_execute // CHECK: tf_device.launch // CHECK: tf.B diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc index 639e924f3200b4..830dd1cb124705 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite_v2.cc @@ -319,15 +319,12 @@ mlir::LogicalResult Rewrite(tf_device::ClusterFuncOp cluster_func, // Fetch compilation device std::string compilation_device; + if (failed(GetClusterFuncDevice(cluster_func, compilation_device))) + return failure(); + if (!old_parallel_execute) { - if (failed(GetClusterFuncDevice(cluster_func, compilation_device))) - return failure(); old_parallel_execute = mlir::TF::BuildParallelExecuteOp(cluster_func, &builder); - } else { - if (failed(GetCompilationDeviceFromParallelExecuteOp(old_parallel_execute, - compilation_device))) - return failure(); } // Build compile op _XlaCompile From 69541bba3ad52137f575ffde39793bc9140c78ad Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Wed, 26 Jul 2023 13:14:10 -0700 Subject: [PATCH 197/410] #tf-data-service Graduate `"data_transfer"` experiment. 
PiperOrigin-RevId: 551295403 --- tensorflow/core/data/dataset_utils.cc | 2 +- tensorflow/core/data/service/client/data_service_client.cc | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 733ccb42554444..c31e7ec1fc5965 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -974,7 +974,7 @@ REGISTER_DATASET_EXPERIMENT("stage_based_autotune", RandomJobSamplePercentage<0>, IndependentHostTasks); REGISTER_DATASET_EXPERIMENT("stage_based_autotune_v2", RandomJobSamplePercentage<0>, IndependentHostTasks); -REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<50>, +REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<10>, IndependentHostTasks); diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 5cbb955f8bb075..664153a5dc62ca 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -378,8 +378,6 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { } if (std::string default_protocol = DefaultDataTransferProtocol(); default_protocol != kGrpcTransferProtocol) { - LOG(INFO) - << "This task is participating in the \"data_transfer\" experiment."; StatusOr transfer_server = GetTransferServer(default_protocol, task_info); if (transfer_server.ok()) { From c2a9dd367b522244449e8d2392ff178a892af8bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 13:22:33 -0700 Subject: [PATCH 198/410] Potential improvement opportunity to eliminate extra transpose op. This CL will add patterns to fold Transpose and FC to covert into a BMM, like below- FC(lhs, Transpose(rhs)) -> BMM(lha, rhs, false, false) The right thing to do in this pattern will be to apply the pattern only if keep_num_dims==True. Because, if the output rank is less-than the input rank, it means `keep_num_dims` has reduced the output. But checking for rank will improve the coverage. 
This pattern will now work PiperOrigin-RevId: 551297769 --- .../tests/end2end/unroll_batch_matmul.pbtxt | 7 +++-- .../compiler/mlir/lite/tests/optimize.mlir | 28 ------------------- .../mlir/lite/transforms/optimize_patterns.td | 26 ----------------- tensorflow/lite/python/analyzer_test.py | 18 +++++++----- 4 files changed, 16 insertions(+), 63 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt index 1d6c0ef73bbab9..0657bcbf8126d5 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -78,12 +78,15 @@ versions { } # CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { +# CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +# CHECK-DAG: %[[VAL_3:.*]] = "tfl.no_value"() {value} : () -> none # CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : tensor # CHECK: %[[VAL_7:.*]]:2 = "tfl.split"(%[[VAL_6]], %[[VAL_0]]) {num_splits = 2 : i32} : (tensor, tensor<2x5x3xf32>) -> (tensor<1x5x3xf32>, tensor<1x5x3xf32>) # CHECK: %[[VAL_8:.*]] = "tfl.reshape"(%[[VAL_7]]#0, %cst) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> # CHECK: %[[VAL_9:.*]] = "tfl.reshape"(%[[VAL_7]]#1, %cst) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> -# CHECK: %[[VAL_11:.*]] = "tfl.batch_matmul"(%[[VAL_8]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<5x3xf32>, tensor<3x7xf32>) -> tensor<5x7xf32> -# CHECK: %[[VAL_12:.*]] = "tfl.batch_matmul"(%[[VAL_9]], %[[VAL_1]]) {adj_x = false, adj_y = false} : (tensor<5x3xf32>, tensor<3x7xf32>) -> tensor<5x7xf32> +# CHECK: %[[VAL_10:.*]] = "tfl.transpose"(%[[VAL_1]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.fully_connected"(%[[VAL_8]], %[[VAL_10]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_12:.*]] = "tfl.fully_connected"(%[[VAL_9]], %[[VAL_10]], %[[VAL_3]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> # CHECK: %[[VAL_13:.*]] = "tfl.pack"(%[[VAL_11]], %[[VAL_12]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> # CHECK: return %[[VAL_13]] : tensor<2x5x7xf32> # CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 8d8c8d6eb50f90..da0cc252529c61 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -686,34 +686,6 @@ func.func @FuseTransposeIntoBMM_RHS2(%arg0: tensor, %arg1: tensor } -// CHECK-LABEL: @FuseTransposeIntoFC_RHS -func.func @FuseTransposeIntoFC_RHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<256x1440xf32>) -> tensor<1x4x1440x1440xf32> { - %cst = "tfl.no_value"() {value} : () -> none - %cst_1 = arith.constant dense<[1, 0]> : tensor<2xi32> - %32 = "tfl.transpose"(%arg1, %cst_1) : (tensor<256x1440xf32>, tensor<2xi32>) -> tensor<1440x256xf32> - %33 = "tfl.fully_connected"(%arg0, %32, %cst) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = 
"DEFAULT"} : (tensor<1x4x1440x256xf32>, tensor<1440x256xf32>, none) -> tensor<1x4x1440x1440xf32> - return %33 : tensor<1x4x1440x1440xf32> - // CHECK: %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x4x1440x256xf32>, tensor<256x1440xf32>) -> tensor<1x4x1440x1440xf32> - // CHECK: return %0 : tensor<1x4x1440x1440xf32> -} - -// CHECK-LABEL: @FuseTransposeIntoFC_RHS1 -func.func @FuseTransposeIntoFC_RHS1(%arg0: tensor<1x384x64x!quant.uniform>, %arg1: tensor<1x64x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> { - %cst = "tfl.no_value"() {value} : () -> none - %cst_6 = arith.constant dense<[1, 0]> : tensor<2xi32> - %cst_7 = arith.constant dense<[64, 384]> : tensor<2xi32> - %cst_8 = arith.constant dense<[1, 384, 384]> : tensor<3xi32> - %46 = "tfl.reshape"(%arg1, %cst_7) : (tensor<1x64x384x!quant.uniform>, tensor<2xi32>) -> tensor<64x384x!quant.uniform> - %53 = "tfl.transpose"(%46, %cst_6) : (tensor<64x384x!quant.uniform>, tensor<2xi32>) -> tensor<384x64x!quant.uniform> - %33 = "tfl.fully_connected"(%arg0, %53, %cst) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x384x64x!quant.uniform>, tensor<384x64x!quant.uniform>, none) -> tensor<384x384x!quant.uniform> - %34 = "tfl.reshape"(%33, %cst_8) : (tensor<384x384x!quant.uniform>, tensor<3xi32>) -> tensor<1x384x384x!quant.uniform> - return %34 : tensor<1x384x384x!quant.uniform> - // CHECK: %cst = arith.constant dense<[64, 384]> : tensor<2xi32> - // CHECK: %0 = "tfl.reshape"(%arg1, %cst) : (tensor<1x64x384x!quant.uniform>, tensor<2xi32>) -> tensor<64x384x!quant.uniform> - // CHECK: %1 = "tfl.batch_matmul"(%arg0, %0) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x384x64x!quant.uniform>, tensor<64x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> - // CHECK: return %1 : tensor<1x384x384x!quant.uniform> -} - // CHECK-LABEL: @FuseTransposeIntoBMM_LHS func.func @FuseTransposeIntoBMM_LHS(%arg0: tensor<1x4x1440x256xf32>, %arg1: tensor<1x1440x256xf32>) -> tensor<1x4x256x256xf32> { %cst_1 = arith.constant dense<[0, 2, 1]> : tensor<3xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 69a9f19619ed77..34a0cd741f6998 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -441,12 +441,6 @@ def IsRankLessThanEqualTo : Constraint().getRank() <= " "$1.getType().cast().getRank()">>; -// Constraint check to test if the rank of an element in a ValueRange($0) -// is equal to the value($1) -class IsValueArrayElementRankEqualTo : Constraint().getRank() == " - "$1.getType().cast().getRank()">>; - def Flatten : NativeCodeCall< "$0.cast()" ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " @@ -1512,23 +1506,3 @@ def FuseTransposeIntoBatchMatMulLHS: Pat< [(AreLastTwoDimsTransposed $perm_value), (IsBoolAttrEqual<"false"> $adj_x)]>; -// Fuse redundant TFL_TransposeOp into TFL_FullyConnecedOp to form TFL_BatchMatMulOp -def FuseTransposeIntoFullyConnectedRHS: Pat< - (TFL_FullyConnectedOp:$output $lhs, - (TFL_TransposeOp $input, (Arith_ConstantOp:$perm_value $p0)), - $bias, - TFL_AF_None, - TFL_FCWO_Default, - $keep_num_dims, - $asymmetric_quantize_inputs), - (TFL_BatchMatMulOp $lhs, $input, - ConstBoolAttrFalse, ConstBoolAttrFalse, $asymmetric_quantize_inputs), - [(IsNoneType $bias), - //if the output rank is 
less-than the input rank, it means keep_num_dims has - //reduced the output. The right thing to do here will be to apply the pattern - //only if keep_num_dims==True. But checking for rank will improve the - //coverage. This pattern will now work for all the transpose->fc patterns - //where the fc has not done imlicit reshape of the output - (IsValueArrayElementRankEqualTo<0> $output, $lhs), - (AreLastTwoDimsTransposed $perm_value)]>; - diff --git a/tensorflow/lite/python/analyzer_test.py b/tensorflow/lite/python/analyzer_test.py index a262823a89affd..22a6a939ba900c 100644 --- a/tensorflow/lite/python/analyzer_test.py +++ b/tensorflow/lite/python/analyzer_test.py @@ -221,19 +221,23 @@ def func(lhs, rhs): with test.mock.patch.object(sys, 'stdout', mock_stdout): analyzer.ModelAnalyzer.analyze(model_content=fb_model) txt = mock_stdout.getvalue() - self.assertIn('Op#0 RESHAPE(T#1, T#4[512, 512]) -> [T#5]', txt) - self.assertIn('Op#1 RESHAPE(T#0, T#3[100, 512]) -> [T#6]', txt) - self.assertIn('Op#2 BATCH_MATMUL(T#6, T#5) -> [T#7]', txt) - self.assertIn('Op#3 RESHAPE(T#7, T#2[1, 100, 8, 64]) -> [T#8]', txt) + self.assertIn('Op#0 RESHAPE(T#1, T#5[512, 512]) -> [T#6]', txt) + self.assertIn('Op#1 RESHAPE(T#0, T#3[100, 512]) -> [T#7]', txt) + self.assertIn('Op#2 TRANSPOSE(T#6, T#4[1, 0]) -> [T#8]', txt) + self.assertIn('Op#3 FULLY_CONNECTED(T#7, T#8, T#-1) -> [T#9]', txt) + self.assertIn('Op#4 RESHAPE(T#9, T#2[1, 100, 8, 64]) -> [T#10]', txt) self.assertIn( 'T#2(einsum/Einsum) shape:[4], type:INT32 RO 16 bytes, ' 'buffer: 3, data:[1, 100, 8, 64]', txt) self.assertIn( - 'T#3(einsum/Einsum1) shape:[2], type:INT32 RO 8 bytes, ' + 'T#3(einsum/Einsum2) shape:[2], type:INT32 RO 8 bytes, ' 'buffer: 4, data:[100, 512]', txt) self.assertIn( - 'T#4(einsum/Einsum2) shape:[2], type:INT32 RO 8 bytes, ' - 'buffer: 5, data:[512, 512]', txt) + 'T#4(einsum/Einsum3) shape:[2], type:INT32 RO 8 bytes, ' + 'buffer: 5, data:[1, 0]', txt) + self.assertIn( + 'T#5(einsum/Einsum4) shape:[2], type:INT32 RO 8 bytes, ' + 'buffer: 6, data:[512, 512]', txt) if __name__ == '__main__': test.main() From f9e20454fed47661defce813dc7e9d20145fc82e Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 26 Jul 2023 15:01:17 -0700 Subject: [PATCH 199/410] Remove one use of inlining in XlaCallModule shape refinement. To improve debuggability, we want the shape refinement to make as few changes as possible to the module. In this change we remove one use of inlining. 
PiperOrigin-RevId: 551325242 --- .../compiler/tests/xla_call_module_test.py | 4 +- .../tf2xla/kernels/xla_call_module_loader.cc | 39 +++++++++++++------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 8f7bed7fe54d8a..8b75dc450188b1 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -249,7 +249,7 @@ def test_wrong_actual_args_errors(self): # x: f32[a, 2], return x module, version = serialize(""" module @jit_f.0 attributes {jax.uses_shape_polymorphism = true} { - func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + func.func public @main(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor { return %arg0 : tensor } } @@ -279,7 +279,7 @@ def f(x, y): with self.assertRaisesRegex( errors.InvalidArgumentError, 'Element type mismatch for argument 1 passed to XlaCallModule: ' - r'expecting tensor<\?x\?xi32>, got tensor<2x3xf32>', + r'expecting tensor<\*xi32>, got tensor<2x3xf32>', ): self._assertOpOutputMatchesExpected(f, (x, y_bad_etype), (x,)) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 79257958431c7c..1557c7fd2dbc8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -375,23 +375,40 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes( } } - // Refine 'main' argument types to use static input types instead. The main - // arguments may occur as return values, or as inputs to called functions, - // and changing their types may invalidate the module. To prevent this - // we insert dummy conversion ops as the sole uses of the main arguments. - mlir::OpBuilder op_builder(module_->getBodyRegion()); - op_builder.setInsertionPointToStart(&main_body); - for (auto i = 0; i < main_body.getNumArguments(); ++i) { - mlir::BlockArgument arg = main_body.getArgument(i); - auto convert_op = op_builder.create( - arg.getLoc(), arg.getType(), arg); - arg.replaceAllUsesExcept(convert_op, convert_op); + // Refine 'main' argument types to use static input types instead. + // This will only change the argument types and will not propagate the + // additional type information further. For that, we'll need to run + // shape refinement as explained below. + // Before refining the argument types it is useful to run the inliner to + // remove calls that may be called with the input arguments. + { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + + mlir::PassManager pm_inline(module_->getContext()); + applyTensorflowAndCLOptions(pm_inline); + pm_inline.addPass(mlir::createInlinerPass()); + + if (mlir::failed(pm_inline.run(*module_))) { + return absl::InvalidArgumentError(absl::StrCat( + "Module inlining failed: ", diag_handler.ConsumeStatus().ToString())); + } } auto static_array_output_types = llvm::to_vector(main_.getResultTypes()); for (auto i = 0; i < main_body.getNumArguments(); ++i) { auto arg = main_body.getArgument(i); arg.setType(static_array_input_types[i]); + // If the argument is used by `func.return`, then we also need to + // update the function result types. It's not great that we need this hack, + // but in the future when we have stablehlo.func, stablehlo.return, etc, + // this will not be needed. 
+ // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is + // fixed, clean this up. + for (mlir::OpOperand &use : arg.getUses()) { + if (auto ret = llvm::dyn_cast(use.getOwner())) { + static_array_output_types[use.getOperandNumber()] = arg.getType(); + } + } } main_.setType(builder.getFunctionType(static_array_input_types, static_array_output_types)); From fe33928038aca10e23e2f58c08869ced9fd7311e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 16:23:55 -0700 Subject: [PATCH 200/410] Merge consecutive Pad operators PiperOrigin-RevId: 551347216 --- .../mlir/lite/stablehlo/tests/optimize.mlir | 164 ++++++++++++++++++ .../lite/stablehlo/transforms/optimize.cc | 100 +++++++++++ 2 files changed, 264 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir index d59c5488240fdd..722fc5b47459f8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir @@ -102,6 +102,170 @@ func.func @testRemoveReshapeAroundDot(%arg0: tensor<1x1x512xf32>, %arg1: tensor< // ----- +// CHECK-LABEL: testTwoConsecutivePads +func.func @testTwoConsecutivePads(%arg0: tensor<10x10x10xf32>) -> (tensor<12x12x12xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<11x11x11xf32>, tensor) -> tensor<12x12x12xf32> + return %3 : tensor<12x12x12xf32> +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> +// CHECK: return %[[RES]] : tensor<12x12x12xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsNegativeLowPad +func.func @testTwoConsecutivePadsNegativeLowPad(%arg0: tensor<10x10x10xf32>) -> (tensor<10x10x10xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<0> : tensor<3xi64>, edge_padding_low = dense<-1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<9x9x9xf32>, tensor) -> tensor<10x10x10xf32> + return %3 : tensor<10x10x10xf32> + +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<-1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> +// CHECK: return %[[RES]] : tensor<10x10x10xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsTwoNegativeHighPad +func.func @testTwoConsecutivePadsTwoNegativeHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<9x9x9xf32>) { + %0 
= mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> + return %3 : tensor<9x9x9xf32> + +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<-2> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<9x9x9xf32> +// CHECK: return %[[RES]] : tensor<9x9x9xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsPositiveNegativeHighPad +func.func @testTwoConsecutivePadsPositiveNegativeHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<11x11x11xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<11x11x11xf32> + return %3 : tensor<11x11x11xf32> + +// CHECK: %[[RES:.*]] = "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<0> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> +// CHECK: return %[[RES]] : tensor<11x11x11xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsNegativePositiveHighPad +func.func @testTwoConsecutivePadsNegativePositiveHighPad(%arg0: tensor<10x10x10xf32>) -> (tensor<11x11x11xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<-1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> + return %3 : tensor<11x11x11xf32> + +// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<-1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<10x10x10xf32> + +// CHECK: "mhlo.pad"(%1, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<0> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<11x11x11xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsDifferentPadVal +func.func @testTwoConsecutivePadsDifferentPadVal(%arg0: 
tensor<10x10x10xf32>) -> (tensor<14x14x14xf32>) { + %0 = mhlo.constant dense<1.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> + return %3 : tensor<14x14x14xf32> + +// CHECK: "mhlo.pad"(%arg0, %1) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + +// CHECK: "mhlo.pad"(%2, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<12x12x12xf32>, tensor) -> tensor<14x14x14xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsDifferentUsers +func.func @testTwoConsecutivePadsDifferentUsers(%arg0: tensor<10x10x10xf32>) -> (tensor<13x13x13xf32>, tensor<12x12x12xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %2 = mhlo.exponential %1 : tensor<12x12x12xf32> + %3 = mhlo.constant dense<0.000000e+00> : tensor + %4 = "mhlo.pad"(%1, %3) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> + return %4, %2 : tensor<13x13x13xf32>, tensor<12x12x12xf32> + +// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + +// CHECK: "mhlo.pad"(%1, %0) { +// CHECK-SAME: edge_padding_high = dense<1> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<0> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> +} + +// ----- + +// CHECK-LABEL: testTwoConsecutivePadsMultipleDownstreamUsers + func.func @testTwoConsecutivePadsMultipleDownstreamUsers(%arg0: tensor<10x10x10xf32>) -> (tensor<13x13x13xf32>, tensor<13x13x13xf32>) { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<1> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<10x10x10xf32>, tensor) -> tensor<12x12x12xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor + %3 = "mhlo.pad"(%1, %2) {edge_padding_high = dense<1> : tensor<3xi64>, edge_padding_low = dense<0> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<12x12x12xf32>, tensor) -> tensor<13x13x13xf32> + %4 = mhlo.exponential %3 : tensor<13x13x13xf32> + %5 = mhlo.tanh %3 : tensor<13x13x13xf32> + return %4, %5 : tensor<13x13x13xf32>, 
tensor<13x13x13xf32> + +// CHECK: "mhlo.pad"(%arg0, %0) { +// CHECK-SAME: edge_padding_high = dense<2> : tensor<3xi64>, +// CHECK-SAME: edge_padding_low = dense<1> : tensor<3xi64>, +// CHECK-SAME: interior_padding = dense<0> : tensor<3xi64> +// CHECK-SAME: } : (tensor<10x10x10xf32>, tensor) -> tensor<13x13x13xf32> + +// CHECK: mhlo.exponential %1 : tensor<13x13x13xf32> +// CHECK: mhlo.tanh %1 : tensor<13x13x13xf32> +// CHECK: return %2, %3 : tensor<13x13x13xf32>, tensor<13x13x13xf32> +} + +// ----- + // CHECK-LABEL: testLiftDotConcatLHSSimple func.func @testLiftDotConcatLHSSimple(%arg0: tensor<1x1x512xf32>, %arg1: tensor<2x1x512xf32>, %arg2: tensor<3x1x512xf32>, %arg3: tensor<512x13xf32>) -> tensor<6x1x13xf32> { %0 = "mhlo.dot_general"(%arg0, %arg3) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index 8392c307fb939c..ffec1aa3b1b924 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -451,6 +452,104 @@ LogicalResult FuseSliceConcat(mhlo::ConcatenateOp concat, return success(); } +// Converts: +// %y1 = pad(%x, pad_val, (p1_1,p1_2,p1_3, ...)) +// %y2 = pad(%y1, pad_val, (p2_1,p2_2,p2_3, ...)) +// To: +// %z = pad(%x, pad_val, (p1_1 + p2_1, p1_2 + p2_2, p1_3 + p2_3, ...)) +LogicalResult MergeConsecutivePad(mhlo::PadOp pad_op, + PatternRewriter &rewriter) { + // Fail for non-static shapes + if (!pad_op.getOperand().getType().hasStaticShape() || + !pad_op.getResult().getType().hasStaticShape() || + !pad_op.getPaddingValue().getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(pad_op, "dynamic shapes not supported"); + } + + // Check if the operand is also a Pad op + auto parent_pad = + dyn_cast_or_null(pad_op.getOperand().getDefiningOp()); + if (!parent_pad) { + return rewriter.notifyMatchFailure(pad_op, "parent is not a pad operator"); + } + + // We need the parent pad to have exactly one use (which is the child pad), + // otherwise merging the two pads will create wrong shapes for the other + // users. + if (!parent_pad->hasOneUse()) { + return rewriter.notifyMatchFailure(pad_op, + "parent pad has more than one use"); + } + + // Fail for non-static shapes + if (!parent_pad.getOperand().getType().hasStaticShape() || + !parent_pad.getResult().getType().hasStaticShape() || + !parent_pad.getPaddingValue().getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(parent_pad, + "dynamic shapes not supported"); + } + + // Check if the padding values are equal (otherwise merging is illegal) + // Because we are using the greedy pattern rewrite driver + // (applyPatternsAndFoldGreedily), all different constant operators with the + // same value will be replaced by a single constant operator of that value. + // Due to this, if the padding values in the input are equal, they will become + // the same constant operator and the following check (which compares memory + // addresses) works. 
+ if (pad_op.getPaddingValue() != parent_pad.getPaddingValue()) { + return rewriter.notifyMatchFailure( + pad_op, "parent and child pad have different padding values"); + } + + // NOTE: Because negative paddings are allowed, we assert that if + // `parent_pad < 0` then `child_pad <= 0` The effect of the negative pad is to + // remove values, so for example if we have parent_pad = - 1, child_pad = 1 + // the merged pad will not change anything, while the un-merged will remove a + // value, then insert a 0 at its place. This only holds for low and high pads, + // the spec does not allow negative interior pads, so we don't check there. + auto low_pads = pad_op.getEdgePaddingLow().getValues(); + auto parent_low_pads = + parent_pad.getEdgePaddingLow().getValues(); + auto high_pads = pad_op.getEdgePaddingHigh().getValues(); + auto parent_high_pads = + parent_pad.getEdgePaddingHigh().getValues(); + auto interior_pads = pad_op.getInteriorPadding().getValues(); + auto parent_interior_pads = + parent_pad.getInteriorPadding().getValues(); + + // NOTE: Low/High/Interior pads have the same size + for (int i = 0; i < low_pads.size(); ++i) { + if (parent_low_pads[i].getInt() < 0 && low_pads[i].getInt() > 0) { + return rewriter.notifyMatchFailure( + pad_op, "can't merge consecutive negative and positive low pads"); + } + if (parent_high_pads[i].getInt() < 0 && high_pads[i].getInt() > 0) { + return rewriter.notifyMatchFailure( + pad_op, "can't merge consecutive negative and positive high pads"); + } + } + + std::vector new_low_pads(low_pads.size(), 0); + std::vector new_high_pads(high_pads.size(), 0); + std::vector new_interior_pads(interior_pads.size(), 0); + + for (int i = 0; i < low_pads.size(); ++i) { + new_low_pads[i] = low_pads[i].getInt() + parent_low_pads[i].getInt(); + new_high_pads[i] = high_pads[i].getInt() + parent_high_pads[i].getInt(); + new_interior_pads[i] = + interior_pads[i].getInt() + parent_interior_pads[i].getInt(); + } + + // Replace pad_op with a new pad having new attributes, taking the + // parent_pad's operand. (After this parent_pad has no users and is removed). + rewriter.replaceOpWithNewOp( + pad_op, pad_op.getType(), parent_pad.getOperand(), + parent_pad.getPaddingValue(), rewriter.getI64TensorAttr(new_low_pads), + rewriter.getI64TensorAttr(new_high_pads), + rewriter.getI64TensorAttr(new_interior_pads)); + return success(); +} + // Convert: // %input : 1xYxC // %1 = mhlo.reshape %param : (1xCxZ) -> CxZ @@ -530,6 +629,7 @@ class OptimizePass patterns.add(LiftDotConcatLHSAndRHS); patterns.add(FuseSliceConcat); patterns.add(ConvertReshapeDotRhsToBatchedDot); + patterns.add(MergeConsecutivePad); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); From 1c26b1cf986a2c9d587ae534d0df98b351377c41 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 26 Jul 2023 16:49:53 -0700 Subject: [PATCH 201/410] update internal files for release PiperOrigin-RevId: 551353292 --- ci/official/requirements_updater/BUILD.bazel | 36 +++++++++++++++++++ .../requirements_updater/release_updater.sh | 33 +++++++++++++++++ requirements_lock_3_10.txt | 14 ++++---- requirements_lock_3_11.txt | 14 ++++---- requirements_lock_3_9.txt | 14 ++++---- 5 files changed, 90 insertions(+), 21 deletions(-) create mode 100644 ci/official/requirements_updater/release_updater.sh diff --git a/ci/official/requirements_updater/BUILD.bazel b/ci/official/requirements_updater/BUILD.bazel index 4f2bb9b8d9edd4..8215ef5698b060 100644 --- a/ci/official/requirements_updater/BUILD.bazel +++ b/ci/official/requirements_updater/BUILD.bazel @@ -37,3 +37,39 @@ compile_pip_requirements_3_11( requirements_in = "requirements.in", requirements_txt = "requirements_lock_3_11.txt", ) + +compile_pip_requirements_3_9( + name = "requirements_3_9_release", + extra_args = [ + "--allow-unsafe", + "-P keras-nightly", + "-P tb-nightly", + "-P tf-estimator-nightly", + ], + requirements_in = "requirements.in", + requirements_txt = "requirements_lock_3_9.txt", +) + +compile_pip_requirements_3_10( + name = "requirements_3_10_release", + extra_args = [ + "--allow-unsafe", + "-P keras-nightly", + "-P tb-nightly", + "-P tf-estimator-nightly", + ], + requirements_in = "requirements.in", + requirements_txt = "requirements_lock_3_10.txt", +) + +compile_pip_requirements_3_11( + name = "requirements_3_11_release", + extra_args = [ + "--allow-unsafe", + "-P keras-nightly", + "-P tb-nightly", + "-P tf-estimator-nightly", + ], + requirements_in = "requirements.in", + requirements_txt = "requirements_lock_3_11.txt", +) diff --git a/ci/official/requirements_updater/release_updater.sh b/ci/official/requirements_updater/release_updater.sh new file mode 100644 index 00000000000000..81c5e5aa89c0ed --- /dev/null +++ b/ci/official/requirements_updater/release_updater.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# script to run pip-compile for keras, tensorboard, estimator deps. +# if there is a change in requirements.in then all lock files will be updated +# accordingly. 
+ +mv BUILD.bazel BUILD + +SUPPORTED_VERSIONS=("3_9" "3_10" "3_11") + +for VERSION in "${SUPPORTED_VERSIONS[@]}" +do + cp ../../../requirements_lock_"$VERSION".txt "requirements_lock_"$VERSION".txt" + bazel run --experimental_convenience_symlinks=ignore //:requirements_"$VERSION"_release.update + sed -i '/^#/d' requirements_lock_"$VERSION".txt + mv "requirements_lock_"$VERSION".txt" ../../../requirements_lock_"$VERSION".txt +done + +mv BUILD BUILD.bazel diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 65c54527d47670..94f2bedffe31b8 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -176,9 +176,9 @@ idna==3.4 \ jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023072407 \ - --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ - --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 +keras-nightly==2.14.0.dev2023072507 \ + --hash=sha256:78f8f218f78a7ee9af4e740bbd4a3ef1bbe1b34e206ae6b6245b38e0461a9109 \ + --hash=sha256:aae547bca76ae23f07b434e064803aace8999a8b6042121bacd5868c51f11025 # via -r ./requirements.in lit==16.0.6 \ --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a @@ -402,16 +402,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230724 \ - --hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe +tb-nightly==2.14.0a20230725 \ + --hash=sha256:a0717250191893ec909066564acbc8c92329041a984cf9e407d301fc3ebf3f3e # via -r ./requirements.in tensorboard-data-server==0.7.1 \ --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023072408 \ - --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 +tf-estimator-nightly==2.14.0.dev2023072508 \ + --hash=sha256:6e810e250d6d13732f3dbd28e8649cbc0eab7cbc35d4154d08b19b749e53a275 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 65c54527d47670..94f2bedffe31b8 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -176,9 +176,9 @@ idna==3.4 \ jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023072407 \ - --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ - --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 +keras-nightly==2.14.0.dev2023072507 \ + --hash=sha256:78f8f218f78a7ee9af4e740bbd4a3ef1bbe1b34e206ae6b6245b38e0461a9109 \ + --hash=sha256:aae547bca76ae23f07b434e064803aace8999a8b6042121bacd5868c51f11025 # via -r ./requirements.in lit==16.0.6 \ --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a @@ -402,16 +402,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230724 \ - 
--hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe +tb-nightly==2.14.0a20230725 \ + --hash=sha256:a0717250191893ec909066564acbc8c92329041a984cf9e407d301fc3ebf3f3e # via -r ./requirements.in tensorboard-data-server==0.7.1 \ --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023072408 \ - --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 +tf-estimator-nightly==2.14.0.dev2023072508 \ + --hash=sha256:6e810e250d6d13732f3dbd28e8649cbc0eab7cbc35d4154d08b19b749e53a275 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 0bda44766a13e9..c0aa9293971d9f 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -180,9 +180,9 @@ importlib-metadata==6.8.0 \ jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 # via -r ./requirements.in -keras-nightly==2.14.0.dev2023072407 \ - --hash=sha256:60ca7fae3ad903eeff858f45ddf9dc0dc395cdf41a07c26e807cb7f07e955441 \ - --hash=sha256:9eb387e3488f5ca87a4686b1ea93b8bd85b36d1006934e130f02185daf192ba5 +keras-nightly==2.14.0.dev2023072507 \ + --hash=sha256:78f8f218f78a7ee9af4e740bbd4a3ef1bbe1b34e206ae6b6245b38e0461a9109 \ + --hash=sha256:aae547bca76ae23f07b434e064803aace8999a8b6042121bacd5868c51f11025 # via -r ./requirements.in lit==16.0.6 \ --hash=sha256:84623c9c23b6b14763d637f4e63e6b721b3446ada40bf7001d8fee70b8e77a9a @@ -406,16 +406,16 @@ six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via google-auth -tb-nightly==2.14.0a20230724 \ - --hash=sha256:eed70d9ddca771b938d7ee90b90673bdbfcd4da8bf292badab2088f26c4612fe +tb-nightly==2.14.0a20230725 \ + --hash=sha256:a0717250191893ec909066564acbc8c92329041a984cf9e407d301fc3ebf3f3e # via -r ./requirements.in tensorboard-data-server==0.7.1 \ --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a # via tb-nightly -tf-estimator-nightly==2.14.0.dev2023072408 \ - --hash=sha256:80b79192e643e923a2e075837785626bce5c35d2d2d34c03eee4fbcea8afb884 +tf-estimator-nightly==2.14.0.dev2023072508 \ + --hash=sha256:6e810e250d6d13732f3dbd28e8649cbc0eab7cbc35d4154d08b19b749e53a275 # via -r ./requirements.in urllib3==1.26.16 \ --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ From 0d22c4bade0f70ee315eeed192d7fe7a3bdf70d4 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 26 Jul 2023 18:45:02 -0700 Subject: [PATCH 202/410] [XLA:Python] Use int64_t instead of ssize_t (part 2) to fix another Mac compiler error Apparently ssize_t is only a long sometimes (at least 32-bit), instead of long long (at least 64-bit). I don't have a mac so I can't repro the failing build, but hopefully this fixes it based on the error message. 
PiperOrigin-RevId: 551376003 --- tensorflow/compiler/xla/python/py_array.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/python/py_array.cc b/tensorflow/compiler/xla/python/py_array.cc index 3325347bb36212..61a4f26149bdcc 100644 --- a/tensorflow/compiler/xla/python/py_array.cc +++ b/tensorflow/compiler/xla/python/py_array.cc @@ -771,7 +771,7 @@ struct ExtraBufferInfo { : buffer(std::move(buffer)), external_reference_hold(std::move(external_reference_hold)) {} - std::vector strides; + std::vector strides; // We keep an external reference hold to the PjRtBuffer. This prevents a // use-after-free in the event that Delete() is called on a buffer with an // live buffer protocol view. It does however mean that Delete() sometimes @@ -877,7 +877,8 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) { if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { extra->strides = ByteStridesForShape( buffer.element_type(), buffer.dimensions(), buffer.layout()); - view->strides = extra->strides.data(); + view->strides = reinterpret_cast( + const_cast(extra->strides.data())); } } } From 60217336ee14d6c92f08802452aa7a8f3e45ee5e Mon Sep 17 00:00:00 2001 From: "Balaji V. Iyer" Date: Wed, 26 Jul 2023 21:18:42 -0700 Subject: [PATCH 203/410] Added a workaround for broadcast. PiperOrigin-RevId: 551401683 --- .../tests/tpu-annotate-dynamic-shape-inputs.mlir | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir index accdbdacca8b1c..75deb34c0f773a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-annotate-dynamic-shape-inputs.mlir @@ -22,7 +22,10 @@ module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/ // CHECK: mhlo.type_extensions func.func @tpu_func ( %arg0: tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<2048xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) { - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2048xi32>, tensor<2048xi32>) -> tensor<2048xi32> - return %0 : tensor<2048xi32> + // TODO(b/292540052): Below tf.addV2 instruction is replaced with just + // returning arg0 due to the workaround mentioned in the above bug. Revert + // this after the bug is fixed. + // %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2048xi32>, tensor<2048xi32>) -> tensor<2048xi32> + return %arg0 : tensor<2048xi32> } } \ No newline at end of file From 54eeb2acce219fd5c8f4827823536d4168c0d64e Mon Sep 17 00:00:00 2001 From: Ce Zheng Date: Wed, 26 Jul 2023 21:56:42 -0700 Subject: [PATCH 204/410] [PJRT] Add PjRtDevice::PoisonExecution. PiperOrigin-RevId: 551408554 --- tensorflow/compiler/xla/pjrt/pjrt_client.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 2aefb40a4996e8..98c96c2d5e7816 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -179,6 +179,22 @@ class PjRtDevice { // Returns the default memory space attached to this device. virtual StatusOr default_memory_space() const = 0; + + // Experimental: Poisons the earliest execution on this device with given + // launch_id if it's not finished yet, i.e. 
makes its output buffers error. + // + // Returns true if the output buffers have been successfully poisoned. + // + // Returns false if the output buffers were not successfully poisoned because + // launch_id is not in the list of executions that have not yet completed. + // This may happen either because the execution corresponding to launch_id has + // already completed, or because an incorrect launch_id was supplied. + // + // Returns error otherwise, including in the case that poisoning is not + // implemented by this client. + virtual StatusOr PoisonExecution(int32_t launch_id, Status error) { + return Unimplemented("PoisonExecution is not supported"); + } }; // Forward declaration. From 333bb69c750ffa60b2247044b0def43625ce9a95 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 26 Jul 2023 22:08:15 -0700 Subject: [PATCH 205/410] [IFRT] Update ShardingParam to also support scalars. PiperOrigin-RevId: 551410772 --- .../compiler/xla/python/ifrt/ir/ifrt_dialect.cc | 7 +++++++ .../compiler/xla/python/ifrt/ir/sharding_param.cc | 7 ++++--- .../xla/python/ifrt/ir/tests/verify_array.mlir | 12 ++++++++++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/python/ifrt/ir/ifrt_dialect.cc b/tensorflow/compiler/xla/python/ifrt/ir/ifrt_dialect.cc index 3f6b8cf4ffa7b6..968b9ef33e91e8 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/ifrt_dialect.cc +++ b/tensorflow/compiler/xla/python/ifrt/ir/ifrt_dialect.cc @@ -130,6 +130,13 @@ mlir::LogicalResult IfrtArrayType::verify( return mlir::failure(); } + if (shape.getRank() != sharding.dim_shards().size()) { + return emitError() << "Requires dim shards to have the same rank as the " + "array. Array rank is " + << shape.getRank() << " vs dim shards rank of " + << sharding.dim_shards().size(); + } + int devices_in_mesh = 1; for (const int axis_size : sharding.minor_to_major().axis_sizes) { devices_in_mesh *= axis_size; diff --git a/tensorflow/compiler/xla/python/ifrt/ir/sharding_param.cc b/tensorflow/compiler/xla/python/ifrt/ir/sharding_param.cc index b50b081c0743d4..59bb529a68b469 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/sharding_param.cc +++ b/tensorflow/compiler/xla/python/ifrt/ir/sharding_param.cc @@ -33,6 +33,10 @@ namespace { template void PrintDims(llvm::raw_ostream& os, llvm::ArrayRef dims) { + if (dims.empty()) { + // A scalar does not have dimensions. 
+ return; + } os << dims[0]; for (int i = 1; i < dims.size(); ++i) { os << "x" << dims[i]; @@ -121,9 +125,6 @@ mlir::LogicalResult ShardingParam::verify( if (mlir::failed(minor_to_major().verify(emit_error))) { return mlir::failure(); } - if (dim_shards().empty()) { - return emit_error() << "Dim shards is empty"; - } int dim_index = 0; int cum_size = 1; diff --git a/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_array.mlir b/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_array.mlir index cb3f9a58122dcf..86eafd22d9508d 100644 --- a/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_array.mlir +++ b/tensorflow/compiler/xla/python/ifrt/ir/tests/verify_array.mlir @@ -26,6 +26,14 @@ func.func @good_array_with_aliased_devices() { // ----- +func.func @good_array_scalar() { + %0 = builtin.unrealized_conversion_cast to + !ifrt.array, to [0,1] on 2x2, [0,1,2,3]> + return +} + +// ----- + func.func @array_devices_should_be_distinct() { // expected-error@+3 {{Device list has duplicate id 0}} // expected-error@+2 {{failed to parse Ifrt_ArrayType parameter 'devices_attr'}} @@ -82,8 +90,8 @@ func.func @array_requires_same_size_of_devices_and_from_axes() { // ----- -func.func @array_requires_non_empty_dim_shards() { - // expected-error@+2 {{Dim shards is empty}} +func.func @array_requires_rank_matching_dim_shards() { + // expected-error@+2 {{Requires dim shards to have the same rank as the array. Array rank is 2 vs dim shards rank of 0}} %0 = builtin.unrealized_conversion_cast to !ifrt.array, to [0,1] on 2x2, [0,1,2,3]> return From b03ea0834739bcc8642c160dad5db3aa60c2272e Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Wed, 26 Jul 2023 22:26:11 -0700 Subject: [PATCH 206/410] [xla:gpu] Fix the BFS algorithm in dataflow analysis The BFS algorithm didn't have a visited set, therefore had a complexity of O(N*E). PiperOrigin-RevId: 551414282 --- .../xla/mlir/backends/gpu/transforms/dataflow_analysis.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc index f4b23864a1fe08..bdbac015109fc9 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/dataflow_analysis.cc @@ -173,14 +173,20 @@ bool Reachable(const DataflowAnalysis::DataflowGraph& graph, size_t from_index, std::queue bfs_queue; bfs_queue.push(from_index); + std::vector visited(graph.size(), false); + while (!bfs_queue.empty()) { size_t index = bfs_queue.front(); + visited[index] = true; bfs_queue.pop(); + if (index == to_index) return true; const DataflowAnalysis::Node& node = graph[index]; for (size_t child_index : node.children) { - bfs_queue.push(child_index); + if (!visited[child_index]) { + bfs_queue.push(child_index); + } } } From d5422e3857a3bcab5063fdd01600d4c15393c887 Mon Sep 17 00:00:00 2001 From: Dan Suh Date: Wed, 26 Jul 2023 22:56:12 -0700 Subject: [PATCH 207/410] Add uniform quantized `stablehlo.convolution` -> `tfl.conv_2d` conversion pattern. This change implements a conversion pattern that converts stablehlo.convolution to tfl.conv_2d. This is a minimal version that converts quantized `stablehlo.convolution` with certain assumptions like that the filter has the format of `[0, 1, i, o]`. 
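As a rough illustration of the relayout this pattern performs (a standalone sketch with placeholder names, not the code added by this change; the pass itself does the equivalent inside `TransposeFilterValue` on a `DenseIntElementsAttr`): filter values stored in `[0, 1, i, o]` (HWIO) order are rewritten into the `[o, 0, 1, i]` (OHWI) order that `tfl.conv_2d` expects, i.e. the index permutation (3, 0, 1, 2), with the quantized dimension moving from 3 to 0.

#include <cstdint>
#include <vector>

// Hypothetical helper: transpose flat HWIO filter data of shape (h, w, i, o)
// into OHWI order. Shapes and element type are simplified for the sketch.
std::vector<int8_t> TransposeHwioToOhwi(const std::vector<int8_t>& hwio,
                                        int64_t h, int64_t w, int64_t i,
                                        int64_t o) {
  std::vector<int8_t> ohwi(hwio.size());
  for (int64_t y = 0; y < h; ++y) {
    for (int64_t x = 0; x < w; ++x) {
      for (int64_t c = 0; c < i; ++c) {
        for (int64_t f = 0; f < o; ++f) {
          // Element [y][x][c][f] of the HWIO filter becomes element
          // [f][y][x][c] of the OHWI filter.
          ohwi[((f * h + y) * w + x) * i + c] =
              hwio[((y * w + x) * i + c) * o + f];
        }
      }
    }
  }
  return ohwi;
}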
PiperOrigin-RevId: 551419638 --- .../uniform-quantized-stablehlo-to-tfl.mlir | 84 ++++ ...uniform_quantized_stablehlo_to_tfl_pass.cc | 396 +++++++++++++++++- 2 files changed, 476 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir index bfd94d7ebb8cb8..70f0b14b209a42 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir @@ -101,3 +101,87 @@ func.func @uniform_dequantize_op_return_f64(%arg: tensor<2x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> + return %1 : tensor<1x3x3x2x!quant.uniform> +} +// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> + +// Note that the quantized dimension is 0, and the shape has been transposed +// to (2, 3, 3, 4). +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> +// CHECK: return %[[CONV2D]] : tensor<1x3x3x2x!quant.uniform> + +// ----- + +// CHECK-LABEL: convolution_op_non_const_filter +func.func @convolution_op_non_const_filter(%arg0: tensor<1x3x3x4x!quant.uniform>, %arg1: tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x3x2x!quant.uniform> + return %0 : tensor<1x3x3x2x!quant.uniform> +} +// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> + +// Confirm that the `stablehlo.convolution` is not converted to `tfl.conv_2d`. +// CHECK: stablehlo.convolution +// CHECK-NOT: tfl.conv_2d + +// ----- + +// Test that if the window padding contains values of 0, the resulting +// `padding` attribute of the `tfl.conv_2d` becomes "VALID". 
+ +// CHECK-LABEL: convolution_op_valid_padding +func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 0], [0, 0]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> + return %1 : tensor<1x1x1x2x!quant.uniform> +} +// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> + +// ----- + +// Test that if the window padding value is missing, the resulting +// `padding` attribute of the `tfl.conv_2d` becomes "VALID". + +// CHECK-LABEL: convolution_op_valid_padding +func.func @convolution_op_valid_padding(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + // The `window` attribute is empty. + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> + return %1 : tensor<1x1x1x2x!quant.uniform> +} +// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x1x1x2x!quant.uniform> +// CHECK: return %[[CONV2D]] : tensor<1x1x1x2x!quant.uniform> + +// ----- + +// Test that if the window stride value is explicitly set, the attribute +// value is transferred to tfl.conv_2d's stridw_h and stride_w values. + +// CHECK-LABEL: convolution_strides +func.func @convolution_strides(%arg0: tensor<1x3x3x4x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> { + %0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform> + // The stride value is explicitly set to [1, 2]. 
+ %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 2], pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4x!quant.uniform>, tensor<3x3x4x2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> + return %1 : tensor<1x3x2x2x!quant.uniform> +} +// CHECK-SAME: %[[ARG:.*]]: tensor<1x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform>, value = dense<3> : tensor<2x3x3x4xi8>} : () -> tensor<2x3x3x4x!quant.uniform> +// CHECK: %[[QCONST_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<2x!quant.uniform>, value = dense<0> : tensor<2xi32>} : () -> tensor<2x!quant.uniform> +// Tests that the stride_w is set to 2. +// CHECK: %[[CONV2D:.*]] = "tfl.conv_2d"(%[[ARG]], %[[QCONST_0]], %[[QCONST_1]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 2 : i32} : (tensor<1x3x3x4x!quant.uniform>, tensor<2x3x3x4x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<1x3x2x2x!quant.uniform> +// CHECK: return %[[CONV2D]] : tensor<1x3x2x2x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index d6548e271cdbc3..588880f1b8a120 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include "llvm/Support/Debug.h" @@ -28,11 +30,14 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" + namespace mlir { namespace odml { namespace { -#define DEBUG_TYPE "uniform-quantized-stablehlo-to-tfl" +using quant::UniformQuantizedPerAxisType; +using quant::UniformQuantizedType; #define GEN_PASS_DEF_UNIFORMQUANTIZEDSTABLEHLOTOTFLPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" @@ -59,6 +64,69 @@ bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { return true; } +// Returns true iff the storage type of `quantized_type` is 8-bit integer. +bool IsStorageTypeI8(QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/8); +} + +// Returns true iff the expressed type of `quantized_type` is f32. +bool IsExpressedTypeF32(QuantizedType quantized_type) { + const Type expressed_type = quantized_type.getExpressedType(); + return expressed_type.isa(); +} + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(const Type type) { + auto quantized_type = type.dyn_cast_or_null(); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. 
Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(const Type type) { + auto quantized_per_axis_type = + type.dyn_cast_or_null(); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + // stablehlo.uniform_quantize -> tfl.quantize class RewriteUniformQuantizeOp : public OpRewritePattern { @@ -78,8 +146,8 @@ class RewriteUniformQuantizeOp return failure(); } - // Output type of `UniformQuantizeOp` is guaranteed to be a quantized tensor - // with integer storage type. + // Output type of `UniformQuantizeOp` is guaranteed to be a quantized + // tensor with integer storage type. const auto output_storage_type = op.getResult() .getType() .cast() @@ -151,12 +219,332 @@ class RewriteUniformDequantizeOp } }; +// Rewrites `stablehlo.convolution` -> `tfl.conv_2d` when it accepts uniform +// quantized tensors. +// +// Conditions for the conversion: +// * Input and output tensors are per-tensor uniform quantized (i8->f32) +// tensors. +// * The filter tensor is constant a per-channel uniform quantized (i8->f32) +// tensor. +// * Convolution is a 2D convolution op and both the input's and filter's +// shape is 4 dimensional. +// * The filter tensor's format is `[0, 1, i, o]`. +// * Not a depthwise convolution. +// * Does not consider bias add fusion. +class RewriteQuantizedConvolutionOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + static LogicalResult MatchInput(Value input) { + auto input_type = input.getType().cast(); + if (input_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected input rank of 4. Got: " + << input_type.getRank() << ".\n"); + return failure(); + } + + if (const auto input_element_type = input_type.getElementType(); + !IsI8F32UniformQuantizedType(input_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected an i8->f32 uniform quantized type. Got: " + << input_element_type << ".\n"); + return failure(); + } + + return success(); + } + + static LogicalResult MatchFilter(Value filter) { + auto filter_type = filter.getType().cast(); + if (filter_type.getRank() != 4) { + LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " + "Expected filter rank of 4. Got: " + << filter_type.getRank() << ".\n"); + return failure(); + } + + const Type filter_element_type = filter_type.getElementType(); + if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { + LLVM_DEBUG( + llvm::dbgs() + << "Expected a per-channel uniform quantized (i8->f32) type. 
Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (filter_element_type.cast() + .getQuantizedDimension() != 3) { + LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. Got: " + << filter_element_type << "\n"); + return failure(); + } + + if (Operation* filter_op = filter.getDefiningOp(); + filter_op == nullptr || !isa(filter_op)) { + LLVM_DEBUG(llvm::dbgs() << "Filter should be a constant.\n"); + return failure(); + } + + return success(); + } + + static LogicalResult MatchOutput(Value output) { + const Type output_element_type = + output.getType().cast().getElementType(); + if (!IsI8F32UniformQuantizedType(output_element_type)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized (i8->f32) type. Got: " + << output_element_type << ".\n"); + return failure(); + } + + return success(); + } + + LogicalResult match(stablehlo::ConvolutionOp op) const override { + stablehlo::ConvDimensionNumbersAttr dimension_numbers = + op.getDimensionNumbers(); + + const int64_t output_dimension = + dimension_numbers.getKernelOutputFeatureDimension(); + if (output_dimension != 3) { + LLVM_DEBUG(llvm::dbgs() << "Expected kernel output feature == 3. Got: " + << output_dimension << ".\n"); + return failure(); + } + + const int64_t input_dimension = + dimension_numbers.getKernelInputFeatureDimension(); + if (input_dimension != 2) { + LLVM_DEBUG(llvm::dbgs() << "Expected kernel input feature == 2. Got: " + << input_dimension << ".\n"); + return failure(); + } + + if (failed(MatchInput(op.getOperand(0)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match input for quantized convolution_op.\n"); + return failure(); + } + + if (failed(MatchFilter(op.getOperand(1)))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match filter for quantized convolution_op.\n"); + return failure(); + } + + if (failed(MatchOutput(op.getResult()))) { + LLVM_DEBUG(llvm::dbgs() + << "Failed to match output for quantized convolution_op.\n"); + return failure(); + } + + return success(); + } + + void rewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + Value filter_value = op.getOperand(1); + Operation* filter_op = filter_value.getDefiningOp(); + + auto filter_uniform_quantized_type = + filter_value.getType() + .cast() + .getElementType() + .cast(); + + // Create a new quantized tensor type for the filter. This is required + // because the quantized dimension is changed from 3 -> 0. `TFL::Conv2DOp` + // requires the quantized dimension to be 0 because it accepts a filter + // tensor of format OHWI + // (https://github.com/tensorflow/tensorflow/blob/5430e5e238f868ce977df96ba89c9c1d31fbe8fa/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L933). + // The quantized dimension should correspond to the output feature + // dimension. + auto new_filter_quantized_type = UniformQuantizedPerAxisType::getChecked( + filter_op->getLoc(), /*flags=*/true, + /*storageType=*/filter_uniform_quantized_type.getStorageType(), + filter_uniform_quantized_type.getExpressedType(), + filter_uniform_quantized_type.getScales(), + filter_uniform_quantized_type.getZeroPoints(), + /*quantizedDimension=*/0, + filter_uniform_quantized_type.getStorageTypeMin(), + filter_uniform_quantized_type.getStorageTypeMax()); + + auto filter_constant_value_attr = cast( + cast(filter_value.getDefiningOp()).getValue()); + + // Using TransposeOp doesn't work because the quantized dimension + // changes which violates the constraint for the TransposeOp that the + // input's and output's element type should be the same. 
+ const DenseIntElementsAttr new_filter_value_attr = TransposeFilterValue( + filter_op->getLoc(), rewriter, filter_constant_value_attr); + + auto new_filter_result_type = RankedTensorType::getChecked( + filter_op->getLoc(), + /*shape=*/new_filter_value_attr.getShapedType().getShape(), + /*type=*/new_filter_quantized_type); + + auto new_filter_constant_op = rewriter.create( + filter_op->getLoc(), /*output=*/TypeAttr::get(new_filter_result_type), + new_filter_value_attr); + + // Create a bias filled with zeros. Mimics the behavior of no bias add. + const int64_t num_output_features = new_filter_result_type.getShape()[0]; + const SmallVector bias_shape = {num_output_features}; + auto bias_quantized_type = UniformQuantizedPerAxisType::getChecked( + op.getLoc(), /*flags=*/true, + /*storageType=*/rewriter.getI32Type(), // i32 for bias + /*expressedType=*/rewriter.getF32Type(), + // TODO: b/292886169 - Set this to be s1 * s2. + /*scales=*/new_filter_quantized_type.getScales(), + /*zeroPoints=*/new_filter_quantized_type.getZeroPoints(), + /*quantizedDimension=*/0, + /*storageTypeMin=*/std::numeric_limits::min(), + /*storageTypeMax=*/std::numeric_limits::max()); + auto bias_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, + bias_quantized_type); + + // Create a bias constant. It should have values of 0. + auto bias_value_type = RankedTensorType::getChecked(op.getLoc(), bias_shape, + rewriter.getI32Type()); + auto bias_value = DenseIntElementsAttr::get( + bias_value_type, APInt(/*numBits=*/32, /*value=*/0, /*isSigned=*/true)); + auto bias = rewriter.create( + op.getLoc(), /*output=*/TypeAttr::get(bias_type), + /*value=*/bias_value); + + // Determine the attributes for the TFL::Conv2DOp. + const std::string padding = GetPadding(op); + const auto [stride_h, stride_w] = GetStrides(op); + const auto [dilation_h_factor, dilation_w_factor] = GetDilationFactors(op); + + Value input_value = op.getOperand(0); + auto tfl_conv2d_op = rewriter.create( + op.getLoc(), /*output=*/op.getResult().getType(), + /*input=*/input_value, + /*filter=*/new_filter_constant_op, /*bias=*/bias.getResult(), + /*dilation_h_factor=*/rewriter.getI32IntegerAttr(dilation_h_factor), + /*dilation_w_factor=*/rewriter.getI32IntegerAttr(dilation_w_factor), + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*padding=*/rewriter.getStringAttr(padding), + /*stride_h=*/rewriter.getI32IntegerAttr(stride_h), + /*stride_w=*/rewriter.getI32IntegerAttr(stride_w)); + + rewriter.replaceAllUsesWith(op.getResult(), tfl_conv2d_op.getResult()); + rewriter.eraseOp(op); + } + + private: + // Transposes the filter tensor to match the filter tensor format for + // `tfl.conv_2d`. This function performs the following index permutation + // only: (3, 0, 1, 2). The filter value is assumed to be of `[0, 1, i, o]` + // format. The `tfl.conv_2d` accepts the filter of `[o, 0, 1, i]`. + // TODO: b/291598373 - Lift the assumption about the filter tensor's format + // and generalize the transpose. 
+ DenseIntElementsAttr TransposeFilterValue( + Location loc, PatternRewriter& rewriter, + const DenseIntElementsAttr& filter_value_attr) const { + ArrayRef filter_shape = + filter_value_attr.getShapedType().getShape(); + SmallVector filter_constant_values; + for (const auto filter_val : filter_value_attr.getValues()) { + filter_constant_values.push_back(filter_val); + } + + SmallVector new_filter_constant_values( + filter_constant_values.size(), 0); + + SmallVector new_filter_shape; + SmallVector transpose_dims = {3, 0, 1, 2}; + for (int i = 0; i < filter_shape.size(); ++i) { + new_filter_shape.push_back(filter_shape[transpose_dims[i]]); + } + + auto get_array_idx = [](ArrayRef shape, const int i, const int j, + const int k, const int l) -> int64_t { + return (i * shape[1] * shape[2] * shape[3]) + (j * shape[2] * shape[3]) + + (k * shape[3]) + l; + }; + + // Transpose the filter value. + for (int i = 0; i < filter_shape[0]; ++i) { + for (int j = 0; j < filter_shape[1]; ++j) { + for (int k = 0; k < filter_shape[2]; ++k) { + for (int l = 0; l < filter_shape[3]; ++l) { + // [i][j][k][l] -> [l][i][j][k] + const int old_idx = get_array_idx(filter_shape, i, j, k, l); + const int new_idx = get_array_idx(new_filter_shape, l, i, j, k); + + new_filter_constant_values[new_idx] = + filter_constant_values[old_idx]; + } + } + } + } + + // Create the new filter constant. + auto new_filter_value_attr_type = + RankedTensorType::getChecked(loc, new_filter_shape, + /*elementType=*/rewriter.getI8Type()); + auto new_filter_constant_value_attr = DenseIntElementsAttr::get( + new_filter_value_attr_type, new_filter_constant_values); + + return new_filter_constant_value_attr; + } + + // Returns the padding attribute used for tfl.conv_2d derived by the padding + // attribute of `op`. + // TODO: b/291599812 - Validate the values for "SAME" padding. + std::string GetPadding(stablehlo::ConvolutionOp op) const { + const DenseIntElementsAttr padding_attr = op.getPaddingAttr(); + if (!padding_attr) { + return "VALID"; + } + if (padding_attr.isSplat() && padding_attr.getSplatValue() == 0) { + return "VALID"; + } + return "SAME"; + } + + // Returns the stride amount for the height and width, respectively. + std::pair GetStrides(stablehlo::ConvolutionOp op) const { + const DenseIntElementsAttr window_strides_attr = op.getWindowStridesAttr(); + if (!window_strides_attr) { + return {1, 1}; // Default values. + } + + const auto window_strides_attr_value = + window_strides_attr.getValues(); + // It is guaranteed from the spec that it has two values: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. + return {window_strides_attr_value[0], window_strides_attr_value[1]}; + } + + // Returns the dilation amount for the height and width, respectively. + std::pair GetDilationFactors( + stablehlo::ConvolutionOp op) const { + const DenseIntElementsAttr lhs_dilation_attr = op.getLhsDilationAttr(); + if (!lhs_dilation_attr) { + return {1, 1}; // Default values. + } + + const auto lhs_dilation_attr_value = lhs_dilation_attr.getValues(); + // It is guaranteed from the spec that it has two values: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution. 
+ return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]}; + } +}; + void UniformQuantizedStablehloToTflPass::runOnOperation() { func::FuncOp func_op = getOperation(); MLIRContext& ctx = getContext(); RewritePatternSet patterns(&ctx); - patterns.add(&ctx); + patterns.add(&ctx); if (failed(applyPatternsAndFoldGreedily(func_op, std::move(patterns)))) { func_op.emitError() << "Failed to convert stablehlo ops with uniform " From 7d501154c752c10ab0ea6a85d765f091be3b0368 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 00:39:07 -0700 Subject: [PATCH 208/410] [XLA] Update PatternMatchUnmergeSharding to avoid illegal function call on SplitShardingDimension. PiperOrigin-RevId: 551438695 --- tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc | 9 +++------ .../compiler/xla/service/spmd/spmd_partitioner_test.cc | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index dff4a2e5188be1..930cffccd78793 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -1770,8 +1770,7 @@ namespace { // Matching a pattern like [..,X,..,Y] -> [..,X*Y,..,1] or [..,X,..,Y] -> // [..,1,..,X*Y]. // Output tuple: -// - HloSharding: The original sharding with an extra dimension dimension added -// of size 1. +// - HloSharding: The original sharding with an extra dimension added of size 1. // - HloSharding: The sharding with the dimension we want to merge moved in // place of the dimension of size 1 we added. // - int: Dimension in the input that is going to be merged with another @@ -1912,10 +1911,8 @@ PatternMatchUnmergeSharding(const Shape& shape, const Shape& base_shape, << target.tile_assignment().dim(target_dim); return std::nullopt; } - return hlo_sharding_util::SplitShardingDimension( - source, i, - source.tile_assignment().dim(i) / - target.tile_assignment().dim(target_dim)); + return hlo_sharding_util::SplitShardingDimension(source, i, + dimension_size); }; for (int j = i - 1; j >= 0; --j) { if (auto reshaped_sharding = get_reshaped_sharding(j)) { diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 5a6b1dee645e4b..9514ff51dc9e6c 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -10039,7 +10039,7 @@ ENTRY %module { PartitionComputation(hlo_string, /*num_devices=*/8)); const auto root = module->entry_computation()->root_instruction(); VLOG(1) << module->ToString(); - auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::Reshape()); + auto operand = AllOf(op::Shape("s32[2,2,2,2]"), op::CollectivePermute()); auto indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract()); auto gather = AllOf(op::Shape("s32[2,2,2,2]"), op::Gather(operand, indices)); EXPECT_THAT(root, op::AllReduce(op::AllReduce( From 9b2016c924538a3024946dcb639bfdb6898d8701 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 27 Jul 2023 01:44:13 -0700 Subject: [PATCH 209/410] =?UTF-8?q?[XLA:GPU]=C2=A0Enable=20using=20region?= =?UTF-8?q?=20analysis=20in=20the=20CopyInsertion=20pass.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is in preparation for another change improving the state of copies in while loops. 
PiperOrigin-RevId: 551451818 --- tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index c150e80ea4e103..3d9f3a066a9956 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -879,7 +879,10 @@ Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { pipeline.AddPass(); } pipeline.AddPass(GetCanShareBuffer()); - pipeline.AddPass(GetCanShareBuffer()); + + constexpr int64_t kNoRegionBasedLiveRangeAnalysisLimit = -1; + pipeline.AddPass(GetCanShareBuffer(), + kNoRegionBasedLiveRangeAnalysisLimit); // We are using a sub-pipeline here, so that the verifier only runs after both // GpuHorizontalLoopFusion and HloDCE. auto& sub_pipeline = From e5e0228b2ba5f0596be0a99599ed20b5d89f0769 Mon Sep 17 00:00:00 2001 From: Matt Kreileder Date: Thu, 27 Jul 2023 01:49:59 -0700 Subject: [PATCH 210/410] Add a tflite model containing tensors that store string data. PiperOrigin-RevId: 551452964 --- .../lite/testdata/tensor_string_data_type.bin | Bin 0 -> 560 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tensorflow/lite/testdata/tensor_string_data_type.bin diff --git a/tensorflow/lite/testdata/tensor_string_data_type.bin b/tensorflow/lite/testdata/tensor_string_data_type.bin new file mode 100644 index 0000000000000000000000000000000000000000..3c10718e935fa30694d2b26bf0e70866cc0e9cc9 GIT binary patch literal 560 zcma)(Jx;?w5QSgB5DTm*g_Wd8;gSN;a0Q@B2L)|R{u&g^0*R8EGjIw{z|G9}cC(5U zfy78}-^}d%?Q9|G;cj*f>7(nqV|W4za0zZe4{UZkgrzg(fitvtKZbALnr-{CTvV%e z`&`|>*6ViOY_3AS>q)2NWt`$nGN50EztB*n2Dslr0~Wyj_OY>7iBCkDo!2MIdB}P6 zNfn0YKH32ab{;10hco#OW4wcJ@Ce+yW-ro7&$D Date: Thu, 27 Jul 2023 02:02:41 -0700 Subject: [PATCH 211/410] Update GraphDef version to 1570. PiperOrigin-RevId: 551455785 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index fa7a6a64e1608c..6fbd1d7cefb1e0 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1569 // Updated: 2023/7/26 +#define TF_GRAPH_DEF_VERSION 1570 // Updated: 2023/7/27 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From dab618eb9c4a06c6298d51f5e89279d18d77f108 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 02:03:39 -0700 Subject: [PATCH 212/410] compat: Update forward compatibility horizon to 2023-07-27 PiperOrigin-RevId: 551455985 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index d9a74cbdaeddb3..ed761a1484ff06 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 27) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 240175440ef2921ad3da92714caa2d072f19b8e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 02:36:57 -0700 Subject: [PATCH 213/410] Parse compiler flags passed as string through env_option_overrides. PiperOrigin-RevId: 551464096 --- .../compiler/xla/pjrt/pjrt_executable.cc | 47 +++++++++++++++++-- .../compiler/xla/pjrt/pjrt_executable.h | 3 ++ .../compiler/xla/pjrt/pjrt_executable_test.cc | 24 ++++++++++ 3 files changed, 69 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc index 2dc5d7881450ed..36ab02d3729de4 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/numbers.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/pjrt/execute_options.pb.h" #include "tensorflow/compiler/xla/util.h" @@ -293,11 +294,9 @@ Status CompileOptions::ApplyOption(const std::string& key, std::holds_alternative(value)) { reflection->SetBool(&debug_options, xla_field, std::get(value)); return OkStatus(); - } else if (xla_field->type() == - tsl::protobuf::FieldDescriptor::TYPE_STRING && - std::holds_alternative(value)) { - reflection->SetString(&debug_options, xla_field, - std::get(value)); + } else if (std::holds_alternative(value)) { + TF_RETURN_IF_ERROR( + ApplyOptionFromString(xla_field, std::get(value))); return OkStatus(); } else if (xla_field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT32 && @@ -327,4 +326,42 @@ Status CompileOptions::ApplyAllOptionOverrides() { return OkStatus(); } +Status CompileOptions::ApplyOptionFromString( + const tsl::protobuf::FieldDescriptor* field, const std::string& value) { + xla::DebugOptions& debug_options = + *executable_build_options.mutable_debug_options(); + const tsl::protobuf::Reflection* reflection = debug_options.GetReflection(); + if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_STRING) { + reflection->SetString(&debug_options, field, value); + return OkStatus(); + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT32) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + reflection->SetInt32(&debug_options, field, int_value); + return OkStatus(); + } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_INT64) { + int int_value; + if (absl::SimpleAtoi(value, &int_value)) { + reflection->SetInt64(&debug_options, field, int_value); + return OkStatus(); + } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_FLOAT) { + float float_value; + if (absl::SimpleAtof(value, &float_value)) { + reflection->SetFloat(&debug_options, field, float_value); + return OkStatus(); + } + } else if (field->type() == tsl::protobuf::FieldDescriptor::TYPE_BOOL) { + bool bvalue = value == "True"; + if (value == "True" || value == "False") { + reflection->SetBool(&debug_options, field, bvalue); + return OkStatus(); + } + } + return InvalidArgument( + "While setting option %s, '%s' is not a valid %s value.", field->name(), + value, field->type_name()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable.h 
b/tensorflow/compiler/xla/pjrt/pjrt_executable.h index 3dc2be6a57d858..a7a00ffafd66dd 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable.h @@ -102,6 +102,9 @@ struct CompileOptions { // Applies a single option to executable_build_options.debug_options(). Status ApplyOption(const std::string& key, const OptionOverride& value); + Status ApplyOptionFromString(const tsl::protobuf::FieldDescriptor* field, + const std::string& value); + // Serialize the CompileOptions into a CompileOptionsProto. StatusOr ToProto() const; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc index a480172117d5f7..adf00146db2c08 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_executable_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/pjrt/pjrt_executable.h" +#include +#include +#include #include #include @@ -90,5 +93,26 @@ TEST(ExecuteOptionsTest, SendRecvNotSupported) { "ExecuteOptions with send/recv calbacks is not serializable")); } +TEST(ExecuteOptionsTest, ApplyOptionsCanParseStrings) { + using OptionOverride = std::variant; + std::vector> env_override_options; + env_override_options = { + {"xla_gpu_use_runtime_fusion", std::string("True")}, + {"xla_gpu_graph_min_graph_size", std::string("2")}, + {"xla_gpu_redzone_scratch_max_megabytes", std::string("3400")}, + {"xla_gpu_auto_spmd_partitioning_memory_budget_ratio", + std::string("0.9")}, + {"xla_gpu_pgle_profile_file_or_directory_path", std::string("abc")}}; + CompileOptions src; + src.env_option_overrides = env_override_options; + auto s = src.ApplyAllOptionOverrides(); + auto& debug_options = src.executable_build_options.debug_options(); + EXPECT_EQ(debug_options.xla_gpu_use_runtime_fusion(), true); + EXPECT_EQ(debug_options.xla_gpu_graph_min_graph_size(), 2); + EXPECT_EQ(debug_options.xla_gpu_redzone_scratch_max_megabytes(), 3400); + EXPECT_FLOAT_EQ( + debug_options.xla_gpu_auto_spmd_partitioning_memory_budget_ratio(), 0.9); + EXPECT_EQ(debug_options.xla_gpu_pgle_profile_file_or_directory_path(), "abc"); +} } // namespace } // namespace xla From 703993de3aef5b53de8d960da0a544bc625ae3f3 Mon Sep 17 00:00:00 2001 From: Crefeda Rodrigues Date: Thu, 27 Jul 2023 09:44:15 +0000 Subject: [PATCH 214/410] Make oneDNN ACL default on Neoverse V1 cores --- tensorflow/core/util/port.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc index 1a63ddd20f3e47..7df44ddf049ed9 100644 --- a/tensorflow/core/util/port.cc +++ b/tensorflow/core/util/port.cc @@ -76,7 +76,9 @@ inline bool DefaultOneDnnPolicy() { port::TestCPUFeature(port::CPUFeature::AVX_VNNI) || port::TestCPUFeature(port::CPUFeature::AMX_TILE) || port::TestCPUFeature(port::CPUFeature::AMX_INT8) || - port::TestCPUFeature(port::CPUFeature::AMX_BF16); + port::TestCPUFeature(port::CPUFeature::AMX_BF16) || + port::TestAarch64CPU( + port::Aarch64CPU::ARM_NEOVERSE_V1); // ARM NEOVERSE V1 #else return false; #endif // !defined(INTEL_MKL) @@ -106,17 +108,11 @@ bool IsMklEnabled() { << oneDNN_enabled; } if (oneDNN_enabled) { -#ifndef DNNL_AARCH64_USE_ACL LOG(INFO) << "oneDNN custom operations are on. 
" << "You may see slightly different numerical results due to " << "floating-point round-off errors from different computation " << "orders. To turn them off, set the environment variable " << "`TF_ENABLE_ONEDNN_OPTS=0`."; -#else - LOG(INFO) << "Experimental oneDNN custom operations are on. " - << "If you experience issues, please turn them off by setting " - << "the environment variable `TF_ENABLE_ONEDNN_OPTS=0`."; -#endif // !DNNL_AARCH64_USE_ACL } }); return oneDNN_enabled; From 5a9e6252acde3596be84fa2038e356d63efe6cf4 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 27 Jul 2023 03:00:33 -0700 Subject: [PATCH 215/410] [XLA:GPU][NFC] Refactor the propagation of dimension orders in the Triton GEMM rewriter. Store dimension orders by outputs of instructions during all traversals. PiperOrigin-RevId: 551469576 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 154 ++++++++++-------- 1 file changed, 82 insertions(+), 72 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 7ac54ff7ce8ae2..82a928071cf208 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -590,13 +590,14 @@ bool IsOutputWorthFusing(const HloInstruction& hlo) { } // Checks if the instruction is possible and profitable to fuse. -// If so tries to transform dim_order describing one side of `hlo` into a -// description of its other side if it is supported by the triton GEMM emitter. +// If so tries to transform dim_order describing one side of `hlo` into +// description(s) of its other side if it is supported. FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, - DimensionOrder& dim_order, + const DimensionOrder& dim_order, absl::flat_hash_map& old_to_new_mapping, - const GpuVersion gpu_version) { + const GpuVersion gpu_version, + std::vector& result_dim_orders) { if (hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { return "Unsupported instruction."; @@ -659,14 +660,30 @@ FusionDecision CanFuse(const HloInstruction& hlo, bool as_input, } } - if (FusionDecision decision = dim_order.HandleInstruction( + DimensionOrder new_dim_order = DimensionOrder(dim_order); + if (FusionDecision decision = new_dim_order.HandleInstruction( &hlo, as_input ? DimensionOrder::TransformDirection::kOutputToInput : DimensionOrder::TransformDirection::kInputToOutput); !decision) { return decision; } - return RequireTritonGemmSupportedDimOrder(dim_order); + if (FusionDecision result = RequireTritonGemmSupportedDimOrder(new_dim_order); + !result) { + return result; + } + result_dim_orders.clear(); + if (as_input) { + result_dim_orders.reserve(hlo.operand_count()); + for (int i = 0; i < hlo.operand_count(); ++i) { + // All currently supported instructions with multiple operands are + // elementwise = have the same dimension orders for all operands. + result_dim_orders.push_back(new_dim_order); + } + } else { + result_dim_orders.push_back(new_dim_order); + } + return FusionDecision{}; } // Clone an instruction into the fusion. @@ -720,9 +737,9 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { // Fuse an instruction with all its fusible inputs. // If an input is not fusible stop there and make a parameter of the new // fusion, otherwise put it onto stack and check its own inputs first. 
-void FuseWithInputsRecursively( - HloInstruction* root, DimensionOrder root_dim_order, - // Dimension orders describing inputs of corresponding instructions. +void TryToFuseWithInputsRecursively( + HloInstruction& root, + // Dimension orders describing outputs of corresponding instructions. absl::flat_hash_map& dim_orders, const GpuVersion gpu_version, absl::flat_hash_map& @@ -739,14 +756,31 @@ void FuseWithInputsRecursively( // Let it change while the scope has one input; afterwards require all // of them to be physically compatible. const HloInstruction* reference_dim_order_hlo = nullptr; - if (CanFuse(*root, /*as_input=*/true, root_dim_order, old_to_new_mapping, - gpu_version)) { - to_fuse.push(root); - inputs.insert(root->operands().begin(), root->operands().end()); - // root_dim_order went through output -> input transformation here. - CHECK(dim_orders.insert({root, root_dim_order}).second) << root->ToString(); - } - visited.insert(root); + auto try_fuse_one = [&](HloInstruction& hlo) { + std::vector operand_dim_orders; + if (!CanFuse(hlo, /*as_input=*/true, dim_orders.at(&hlo), + old_to_new_mapping, gpu_version, operand_dim_orders)) { + return false; + } + for (const DimensionOrder& dim_order : operand_dim_orders) { + if (reference_dim_order_hlo != nullptr && + !dim_order.IsPhysicallyEquivalent( + dim_orders.at(reference_dim_order_hlo))) { + return false; + } + } + to_fuse.push(&hlo); + if (hlo.opcode() != HloOpcode::kParameter) { + inputs.erase(&hlo); + } + for (int i = 0; i < hlo.operand_count(); ++i) { + inputs.insert(hlo.operand(i)); + dim_orders.insert({hlo.operand(i), operand_dim_orders[i]}); + } + return true; + }; + try_fuse_one(root); + visited.insert(&root); while (!to_fuse.empty()) { bool top_is_ready_to_fuse = true; HloInstruction* hlo = to_fuse.top(); @@ -760,27 +794,9 @@ void FuseWithInputsRecursively( NumAddedParameters(*operand) > 0) { continue; } - // Operand's output is described by its consumer's input. - DimensionOrder operand_dim_order(dim_orders.at(hlo)); - // CanFuse() makes output -> input transformation of - // operand_dim_order if succeeds. - if (CanFuse(*operand, /*as_input=*/true, operand_dim_order, - old_to_new_mapping, gpu_version)) { - if (reference_dim_order_hlo != nullptr && - !operand_dim_order.IsPhysicallyEquivalent( - dim_orders.at(reference_dim_order_hlo))) { - continue; - } - to_fuse.push(operand); - if (operand->opcode() != HloOpcode::kParameter) { - inputs.erase(operand); - } - inputs.insert(operand->operands().begin(), operand->operands().end()); + if (try_fuse_one(*operand)) { top_is_ready_to_fuse = false; } - // Save the dimension order description of operand's input. - CHECK(dim_orders.insert({operand, operand_dim_order}).second) - << operand->ToString(); } } if (top_is_ready_to_fuse) { @@ -817,13 +833,14 @@ StatusOr FuseDot(HloInstruction& dot, auto fuse_inputs = [&](int operand_number) -> StatusOr> { + const int operand_count_before = fusion_inputs.size(); absl::flat_hash_map dim_orders; - int operand_count_before = fusion_inputs.size(); // Direct dot inputs have well defined dimension orders. 
- FuseWithInputsRecursively( - dot.mutable_operand(operand_number), - DimensionOrder::FromDotOperand(dot, operand_number), dim_orders, - gpu_version, old_to_new_mapping, fusion_inputs, builder); + dim_orders.insert({dot.operand(operand_number), + DimensionOrder::FromDotOperand(dot, operand_number)}); + TryToFuseWithInputsRecursively(*dot.mutable_operand(operand_number), + dim_orders, gpu_version, old_to_new_mapping, + fusion_inputs, builder); TF_RET_CHECK(fusion_inputs.size() - operand_count_before <= DotFusionAnalysis::kMaxParameterPerScope); return dim_orders; @@ -839,8 +856,9 @@ StatusOr FuseDot(HloInstruction& dot, // the same tiling. auto first_lhs_parameter_it = lhs_dim_orders.cbegin(); while (first_lhs_parameter_it != lhs_dim_orders.cend()) { - if (old_to_new_mapping[first_lhs_parameter_it->first]->opcode() == - HloOpcode::kParameter) { + if (auto it = old_to_new_mapping.find(first_lhs_parameter_it->first); + it != old_to_new_mapping.cend() && + it->second->opcode() == HloOpcode::kParameter) { break; } ++first_lhs_parameter_it; @@ -877,19 +895,19 @@ StatusOr FuseDot(HloInstruction& dot, if (!IsDistributiveOverAddition(*user)) { break; } - // Describes the output of `current_output` = input of `user`. - DimensionOrder dim_order(out_dim_orders.at(fusion_output)); - if (CanFuse(*user, /*as_input=*/false, dim_order, old_to_new_mapping, - gpu_version)) { - // Now it describes the output of the user. - CHECK(out_dim_orders.insert({user, dim_order}).second); + if (std::vector output_dim_order; + CanFuse(*user, /*as_input=*/false, out_dim_orders.at(fusion_output), + old_to_new_mapping, gpu_version, output_dim_order)) { + CHECK(out_dim_orders.insert({user, output_dim_order[0]}).second); for (HloInstruction* operand : user->operands()) { if (!old_to_new_mapping.contains(operand)) { - // Here we need again a dim order describing inputs of the user. - FuseWithInputsRecursively( - operand, DimensionOrder(out_dim_orders.at(fusion_output)), - out_dim_orders, gpu_version, old_to_new_mapping, fusion_inputs, - builder); + // Here using a dimension order of one known operand of `user` for + // the other operand. This is fine for now because all supported + // multi-operand instructions are elementwise. + out_dim_orders.insert({operand, out_dim_orders.at(fusion_output)}); + TryToFuseWithInputsRecursively(*operand, out_dim_orders, gpu_version, + old_to_new_mapping, fusion_inputs, + builder); } } Fuse(*user, old_to_new_mapping, fusion_inputs, builder); @@ -1194,16 +1212,16 @@ Status MakeDotComputationSplitKBatch( } // Propagate dimension orders in consumer->producer direction starting at -// `origin` with input `origin_dim_order` till parameters of the computation. +// `origin` with output `origin_dim_order` till parameters of the computation. // Store the found parameters and their iteration specs. Status PropagateDimensionOrdersToParameters( - const HloInstruction& origin, const DimensionOrder& origin_dim_order, + const HloInstruction& origin, DimensionOrder origin_dim_order, absl::flat_hash_set& parameters, absl::flat_hash_map& iter_specs) { absl::flat_hash_set visited; std::queue to_process; - // Dimension orders describing inputs of corresponding instructions. + // Dimension orders describing outputs of corresponding instructions. 
absl::flat_hash_map dim_orders; TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(origin_dim_order)); dim_orders.insert({&origin, origin_dim_order}); @@ -1231,14 +1249,12 @@ Status PropagateDimensionOrdersToParameters( continue; } // Operand's output is described by its consumer's input. - auto [it, inserted] = - dim_orders.insert({operand, DimensionOrder(dim_orders.at(hlo))}); - TF_RET_CHECK(inserted); - DimensionOrder& hlo_operand_dim_order = it->second; - TF_RET_CHECK(hlo_operand_dim_order.HandleInstruction( - operand, DimensionOrder::TransformDirection::kOutputToInput)) + DimensionOrder operand_dim_order(dim_orders.at(hlo)); + TF_RET_CHECK(operand_dim_order.HandleInstruction( + hlo, DimensionOrder::TransformDirection::kOutputToInput)) << operand->ToString(); - TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(hlo_operand_dim_order)); + TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(operand_dim_order)); + TF_RET_CHECK(dim_orders.insert({operand, operand_dim_order}).second); to_process.push(operand); } } @@ -1405,13 +1421,9 @@ Status DotFusionAnalysis::ExecuteImpl(const HloComputation* computation, for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); const HloInstruction* operand = dot->operand(operand_number); - DimensionOrder dot_operand_dim_order = - DimensionOrder::FromDotOperand(*dot, operand_number, split_k); - TF_RET_CHECK(dot_operand_dim_order.HandleInstruction( - operand, DimensionOrder::TransformDirection::kOutputToInput)); TF_RETURN_IF_ERROR(PropagateDimensionOrdersToParameters( - *operand, dot_operand_dim_order, parameters_[scope], - iter_specs_[scope])); + *operand, DimensionOrder::FromDotOperand(*dot, operand_number, split_k), + parameters_[scope], iter_specs_[scope])); } int64_t lhs_nc_split_major_part_size = -1; @@ -1441,8 +1453,6 @@ Status DotFusionAnalysis::ExecuteImpl(const HloComputation* computation, .second); if (output != dot) { // Propagate back to parameters of the output fusion. - TF_RET_CHECK(dim_order.HandleInstruction( - output, DimensionOrder::TransformDirection::kOutputToInput)); TF_RETURN_IF_ERROR(PropagateDimensionOrdersToParameters( *output, dim_order, parameters_[Scope::OUTPUT], iter_specs_[Scope::OUTPUT])); From e397708416a3d3284a47d299703b900e113c28bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Thu, 27 Jul 2023 03:44:18 -0700 Subject: [PATCH 216/410] [XLA:GPU] Fix for "double free in RegisterXlaFfiStreamProvider" when running compilations in parallel This was triggered during the run of exhaustive autotuning for Triton fusions. 1. We are protecting access to the stream_providers vector with a mutex, preventing the crash. 2. We are only storing each StreamProvider once. 3. We are protecting the modules vector with a mutex as well, out of caution. 
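The guarding applied to both vectors follows the same shape; as a standalone sketch (placeholder `Provider` type and names rather than the actual XLA declarations), it looks like this:

#include <algorithm>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

using Provider = void* (*)();

ABSL_CONST_INIT absl::Mutex providers_mu(absl::kConstInit);

// A function-local static behind an accessor avoids static initialization
// order issues; the annotation documents that callers must hold providers_mu.
static std::vector<Provider>& Providers()
    ABSL_EXCLUSIVE_LOCKS_REQUIRED(providers_mu) {
  static auto* providers = new std::vector<Provider>();
  return *providers;
}

void RegisterProvider(Provider provider) {
  absl::MutexLock lock(&providers_mu);
  std::vector<Provider>& providers = Providers();
  // Register each provider at most once, so concurrent or repeated
  // registration cannot grow the vector or double-register an entry.
  if (std::count(providers.begin(), providers.end(), provider) == 0) {
    providers.push_back(provider);
  }
}

The heap-allocated, never-deleted vector is intentional: it sidesteps destruction-order problems at process shutdown.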
PiperOrigin-RevId: 551479820 --- tensorflow/compiler/xla/runtime/BUILD | 1 + tensorflow/compiler/xla/runtime/ffi.cc | 29 +++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/runtime/BUILD b/tensorflow/compiler/xla/runtime/BUILD index 719ea5e796e715..e4fa0f343bc798 100644 --- a/tensorflow/compiler/xla/runtime/BUILD +++ b/tensorflow/compiler/xla/runtime/BUILD @@ -313,6 +313,7 @@ cc_library( "//tensorflow/compiler/xla/runtime/ffi:ffi_c_api_hdrs", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xla/runtime/ffi.cc b/tensorflow/compiler/xla/runtime/ffi.cc index 58833212510851..7102403d1aa5ee 100644 --- a/tensorflow/compiler/xla/runtime/ffi.cc +++ b/tensorflow/compiler/xla/runtime/ffi.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/runtime/custom_call.h" #include "tensorflow/compiler/xla/runtime/ffi/ffi_c_api.h" #include "tensorflow/compiler/xla/runtime/module.h" @@ -125,17 +126,30 @@ absl::StatusCode ConvertErrorCode(XLA_FFI_Error_Code errc) { using StreamProvider = XLA_FFI_Stream* (*)(const CustomCall::UserData*, const DiagnosticEngine*); -static std::vector& GetStreamProviders() { +namespace { +// For protecting GetStreamProviders(). +ABSL_CONST_INIT absl::Mutex stream_providers_mu(absl::kConstInit); +} // namespace + +static std::vector& GetStreamProviders() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(stream_providers_mu) { static auto* stream_providers = new std::vector(); return *stream_providers; } void RegisterXlaFfiStreamProvider(StreamProvider provider) { - GetStreamProviders().push_back(provider); + absl::MutexLock lock(&stream_providers_mu); + std::vector& stream_providers = GetStreamProviders(); + // AFAIK there is only one stream provider now, so this count operation is not + // slow. + if (absl::c_count(stream_providers, provider) == 0) { + stream_providers.push_back(provider); + } } XLA_FFI_Stream* GetXlaFfiStream(const CustomCall::UserData* user_data, const DiagnosticEngine* diagnostic) { + absl::MutexLock lock(&stream_providers_mu); for (auto provider : GetStreamProviders()) { if (XLA_FFI_Stream* stream = provider(user_data, diagnostic)) { return stream; @@ -336,12 +350,20 @@ static XLA_FFI_Error* CreateError(XLA_FFI_Error_Create_Args* args) { // XLA runtime FFI backend implementation. //===----------------------------------------------------------------------===// -static std::vector& OwnedFfiModules() { +namespace { +// For protecting OwnedFfiModules(). 
+ABSL_CONST_INIT absl::Mutex modules_mu(absl::kConstInit); +} // namespace + +static std::vector& OwnedFfiModules() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(modules_mu) { static auto* modules = new std::vector(); return *modules; } std::vector FfiModules() { + absl::MutexLock lock(&modules_mu); + std::vector modules; absl::c_transform(OwnedFfiModules(), std::back_inserter(modules), [](const FfiModule& module) { return &module; }); @@ -420,6 +442,7 @@ static void RegisterXlaFfiModule(XLA_FFI_Module_Register_Args* args) { exported_functions.push_back(fn); } + absl::MutexLock lock(&modules_mu); auto& modules = OwnedFfiModules(); modules.emplace_back(api(), /*id=*/modules.size(), args->name, args->module, args->state_type, args->create_state, From 3382f0b9f152d39bf9f8e3e8363fd64a04601634 Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Thu, 27 Jul 2023 04:09:39 -0700 Subject: [PATCH 217/410] Fix a bug where the XNNPACK delegate plugin wasn't honouring the 'flags' parameter in the XnnpackSettings FlatBuffer. Testing this required also adding an API function to get the flags given a TfLiteDelegate* that points to an XNNPACK delegate. PiperOrigin-RevId: 551484835 --- .../configuration/c/xnnpack_plugin.cc | 7 +++ .../configuration/c/xnnpack_plugin_test.cc | 57 ++++++++++++++++++- .../delegates/xnnpack/xnnpack_delegate.cc | 10 ++++ .../lite/delegates/xnnpack/xnnpack_delegate.h | 6 ++ 4 files changed, 79 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc index 9cbe4617afe23b..c82b0a7964af04 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc @@ -32,6 +32,13 @@ static TfLiteDelegate* CreateDelegate(const void* settings) { const auto* xnnpack_settings = tflite_settings->xnnpack_settings(); if (xnnpack_settings) { options.num_threads = xnnpack_settings->num_threads(); + // If xnnpack_settings->flags is zero, then leave options.flags + // unmodified, i.e. use the default flags (not zero). + // If xnnpack_settings->flags is nonzero, then use exactly + // those flags (i.e. discard the default flags). + if (xnnpack_settings->flags()) { + options.flags = xnnpack_settings->flags(); + } } return TfLiteXNNPackDelegateCreate(&options); } diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin_test.cc index 224e0efe256ff2..5a5c14e8d86073 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin_test.cc @@ -45,7 +45,7 @@ class XnnpackTest : public testing::Test { settings_ = flatbuffers::GetRoot( flatbuffer_builder_.GetBufferPointer()); } - ~XnnpackTest() override {} + ~XnnpackTest() override = default; protected: // settings_ points into storage owned by flatbuffer_builder_. 
@@ -80,4 +80,59 @@ TEST_F(XnnpackTest, SetsCorrectThreadCount) { EXPECT_EQ(thread_count, kNumThreadsForTest); TfLiteXnnpackDelegatePluginCApi()->destroy(delegate); } + +TEST_F(XnnpackTest, UsesDefaultFlagsByDefault) { + TfLiteDelegate *delegate = + TfLiteXnnpackDelegatePluginCApi()->create(settings_); + int flags = TfLiteXNNPackDelegateGetFlags(delegate); + EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags); + TfLiteXnnpackDelegatePluginCApi()->destroy(delegate); +} + +TEST_F(XnnpackTest, UsesSpecifiedFlagsWhenNonzero) { + XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_flags( + tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QU8); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + + TfLiteDelegate *delegate = + TfLiteXnnpackDelegatePluginCApi()->create(settings_); + int flags = TfLiteXNNPackDelegateGetFlags(delegate); + EXPECT_EQ(flags, tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QU8); + TfLiteXnnpackDelegatePluginCApi()->destroy(delegate); +} + +// Settings flags to XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS (zero) +// causes flags to be set to their default values, not zero. +// This is potentially confusing behaviour, but we can't distinguish +// the case when flags isn't set from the case when flags is set to zero. +TEST_F(XnnpackTest, UsesDefaultFlagsWhenZero) { + XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_flags( + tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + + TfLiteDelegate *delegate = + TfLiteXnnpackDelegatePluginCApi()->create(settings_); + int flags = TfLiteXNNPackDelegateGetFlags(delegate); + EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags); + TfLiteXnnpackDelegatePluginCApi()->destroy(delegate); +} + } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 980e0087189f19..a3f4c9836bbe43 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -6466,6 +6466,16 @@ void* TfLiteXNNPackDelegateGetThreadPool(TfLiteDelegate* delegate) { static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)->threadpool()); } +int TfLiteXNNPackDelegateGetFlags(TfLiteDelegate* delegate) { + if (delegate == nullptr) { + return 0; + } + + auto* xnnpack_delegate = + static_cast<::tflite::xnnpack::Delegate*>(delegate->data_); + return xnnpack_delegate->options().flags; +} + void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate) { if (delegate != nullptr) { ::tflite::xnnpack::Delegate* data = diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h 
index 951e0a9de07e6d..2269527e42202b 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h @@ -78,6 +78,12 @@ TfLiteDelegate* TfLiteXNNPackDelegateCreateWithThreadpool( TFL_CAPI_EXPORT void* TfLiteXNNPackDelegateGetThreadPool( TfLiteDelegate* delegate); +// Returns the flags used for an XNNPack delegate. +// See documentation for TfLiteXNNPackDelegateOptions.flags. +// +// WARNING: This API is experimental and subject to change. +TFL_CAPI_EXPORT int TfLiteXNNPackDelegateGetFlags(TfLiteDelegate* delegate); + // Destroys a delegate created with `TfLiteXNNPackDelegateCreate` call. TFL_CAPI_EXPORT void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate); From 3273c4540338f5305ba6b85a9cbf2fb0153c912b Mon Sep 17 00:00:00 2001 From: Andrew Goodbody Date: Thu, 27 Jul 2023 12:22:12 +0100 Subject: [PATCH 218/410] [Linaro:ARM_CI] Stop using python venv for building and testing The use of hermetic python should make this venv unnecessary. --- .../tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh | 11 +---------- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh | 11 +---------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh index cae5791083a1fa..39e2bf9f103152 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh @@ -52,12 +52,6 @@ sudo chown -R ${CI_BUILD_USER}:${CI_BUILD_GROUP} /usr/lib/python3/dist-packages # Update bazel install_bazelisk -# Set python version string -python_version=$(python3 -c 'import sys; print("python"+str(sys.version_info.major)+"."+str(sys.version_info.minor))') - -# Setup virtual environment -setup_venv_ubuntu ${python_version} - # Need to use the python from the venv export PYTHON_BIN_PATH=$(which python3) @@ -119,9 +113,6 @@ bazel test ${TF_TEST_FLAGS} \ --action_env=PYTHON_BIN_PATH=${PYTHON_BIN_PATH} \ --build_tag_filters=${TF_FILTER_TAGS} \ --test_tag_filters=${TF_FILTER_TAGS} \ - --jobs=32 \ + --local_test_jobs=$(grep -c ^processor /proc/cpuinfo) \ --build_tests_only \ -- ${TF_TEST_TARGETS} - -# Remove virtual environment -remove_venv_ubuntu diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh index 817b30da6cff88..bc581f93052fc4 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh @@ -54,12 +54,6 @@ sudo chown -R ${CI_BUILD_USER}:${CI_BUILD_GROUP} /usr/lib/python3/dist-packages # Update bazel install_bazelisk -# Set python version string -python_version=$(python3 -c 'import sys; print("python"+str(sys.version_info.major)+"."+str(sys.version_info.minor))') - -# Setup virtual environment -setup_venv_ubuntu ${python_version} - # Need to update the version of auditwheel used for aarch64 python3 -m pip install auditwheel~=5.3.0 @@ -172,12 +166,9 @@ bazel test ${TF_TEST_FLAGS} \ --action_env=PYTHON_BIN_PATH=${PYTHON_BIN_PATH} \ --build_tag_filters=${TF_FILTER_TAGS} \ --test_tag_filters=${TF_FILTER_TAGS} \ - --jobs=32 \ + --local_test_jobs=$(grep -c ^processor /proc/cpuinfo) \ --build_tests_only \ -- ${TF_TEST_TARGETS} # remove duplicate wheel and copy wheel to mounted volume for local access rm -rf ${WHL_DIR}/*linux_aarch64.whl && cp -r ${WHL_DIR} . 
- -# Remove virtual environment -remove_venv_ubuntu From 4d47879c030724e29cacdfcceca95bc6dc25152b Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 27 Jul 2023 04:27:22 -0700 Subject: [PATCH 219/410] BatchMatMul implicitly transpose the RHS by changing the order from Row Major to Column Major when adj_y = false. The explicit transpose is removed, reducing memory usage and processing time PiperOrigin-RevId: 551487965 --- tensorflow/lite/kernels/batch_matmul.cc | 63 +++++++++++-------- tensorflow/lite/kernels/batch_matmul_test.cc | 27 +++++++- .../kernels/internal/optimized/batch_matmul.h | 27 ++++++-- 3 files changed, 85 insertions(+), 32 deletions(-) diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index 3b5963e693cc61..6c72d9003ea76c 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -517,7 +517,7 @@ TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data, const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, const RuntimeShape& output_shape, - TfLiteTensor* output) { + TfLiteTensor* output, bool transpose_lhs) { // Reuse params struct from FullyConnected Op. FullyConnectedParams op_params; int32_t input_offset = -lhs->params.zero_point; @@ -539,11 +539,11 @@ TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data, GetTensorData(lhs), GetTensorShape(output), GetTensorData(output)); } else { - optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData(rhs), - lhs_shape, GetTensorData(lhs), - GetTensorShape(output), - GetTensorData(output), - CpuBackendContext::GetFromContext(context)); + optimized_ops::BatchMatMul( + op_params, rhs_shape, GetTensorData(rhs), lhs_shape, + GetTensorData(lhs), GetTensorShape(output), + GetTensorData(output), + CpuBackendContext::GetFromContext(context), transpose_lhs); } return kTfLiteOk; } @@ -555,7 +555,7 @@ TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data, const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, const RuntimeShape& output_shape, - TfLiteTensor* output) { + TfLiteTensor* output, bool transpose_lhs) { // Reuse params struct from FullyConnected Op. 
FullyConnectedParams op_params; int32_t input_offset = -lhs->params.zero_point; @@ -579,11 +579,11 @@ TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data, GetTensorData(lhs), GetTensorShape(output), GetTensorData(output)); } else { - optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData(rhs), - lhs_shape, GetTensorData(lhs), - GetTensorShape(output), - GetTensorData(output), - CpuBackendContext::GetFromContext(context)); + optimized_ops::BatchMatMul( + op_params, rhs_shape, GetTensorData(rhs), lhs_shape, + GetTensorData(lhs), GetTensorShape(output), + GetTensorData(output), + CpuBackendContext::GetFromContext(context), transpose_lhs); } return kTfLiteOk; } @@ -620,7 +620,8 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, const RuntimeShape& rhs_shape, - const TfLiteTensor* rhs, TfLiteTensor* output) { + const TfLiteTensor* rhs, TfLiteTensor* output, + bool transpose_lhs) { if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) { TfLiteTensor* input_quantized; TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, @@ -643,11 +644,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) { if (output->type == kTfLiteInt8) { return EvalInt8Int8(context, data, lhs_shape, lhs, rhs_shape, - rhs, GetTensorShape(output), output); + rhs, GetTensorShape(output), output, + transpose_lhs); } else { return EvalInt8Int32(context, data, lhs_shape, lhs, rhs_shape, rhs, GetTensorShape(output), - output); + output, transpose_lhs); } } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) { return EvalInt16(context, data, lhs_shape, lhs, rhs_shape, rhs, @@ -744,9 +746,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } rhs_dims_count = orig_rhs_shape.DimensionsCount(); lhs_dims_count = orig_lhs_shape.DimensionsCount(); - const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs); + const TfLiteTensor* rhs_tensor = rhs; + bool implicit_transpose_possible = true; + if ((lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) || + kernel_type == kReference || rhs->type == kTfLiteInt16) { + implicit_transpose_possible = false; + } + bool do_implicit_transpose = !adj_y && implicit_transpose_possible; + if (!adj_y && !implicit_transpose_possible) { + rhs_tensor = GetTempRhs(context, node, rhs); + } const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs; - if (!adj_y) { + if (!adj_y && !implicit_transpose_possible) { // TODO(b/154760341) Constant tensors should already be transposed, but // we transpose once if necessary for now. if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) { @@ -757,8 +768,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (adj_x) { TransposeRowsColumns(context, lhs, GetTemporary(context, node, 0)); } - RuntimeShape rhs_shape = - adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape); + RuntimeShape rhs_shape = (adj_y && !do_implicit_transpose) + ? orig_rhs_shape + : SwapRowColumnDims(orig_rhs_shape); RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape); @@ -766,11 +778,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteFloat32: // Note we pass RHS args first, LHS args second. See note above. 
if (kernel_type == kGenericOptimized) { - optimized_ops::BatchMatMul(rhs_shape, GetTensorData(rhs_tensor), - lhs_shape, GetTensorData(lhs_tensor), - GetTensorShape(output), - GetTensorData(output), - CpuBackendContext::GetFromContext(context)); + optimized_ops::BatchMatMul( + rhs_shape, GetTensorData(rhs_tensor), lhs_shape, + GetTensorData(lhs_tensor), GetTensorShape(output), + GetTensorData(output), + CpuBackendContext::GetFromContext(context), do_implicit_transpose); } else { reference_ops::BatchMatMul(rhs_shape, GetTensorData(rhs_tensor), lhs_shape, GetTensorData(lhs_tensor), @@ -781,7 +793,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: case kTfLiteInt16: EvalQuantized(context, node, op_data, lhs_shape, lhs_tensor, - rhs_shape, rhs_tensor, output); + rhs_shape, rhs_tensor, output, + do_implicit_transpose); break; default: TF_LITE_KERNEL_LOG(context, diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc index 9e968865f1039c..ad9d8e921f69be 100644 --- a/tensorflow/lite/kernels/batch_matmul_test.cc +++ b/tensorflow/lite/kernels/batch_matmul_test.cc @@ -958,9 +958,11 @@ class QuantizedBatchMatMulOpModel : public SingleOpModel { } input_size_ = total_input_size / batches_; + int rhs_batch_size = adj_y ? units_ : input_size_; + int rhs_channels = adj_y ? input_size_ : units_; lhs_id_ = AddInput(lhs); rhs_id_ = AddInput({lhs.type, - {input_size_, units_}, + {rhs_batch_size, rhs_channels}, 0, 0, GetScale(lhs_id_), @@ -1034,6 +1036,29 @@ TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) { EXPECT_THAT(m.GetOutput(), ElementsAre(22, 22, 22, 56, 56, 56)); } +TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8AdjRHS) { + QuantizedBatchMatMulOpModel m( + /*units=*/3, /*batches*/ 2, + /*lhs=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_INT8, {}, -127, 128}, false, true); + + m.SetWeights({ + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + }); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({14, 65, 128, 20, 95, 128}))); + EXPECT_THAT(m.GetOutput(), ElementsAre(13, 64, 127, 19, 94, 127)); +} + TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16) { const float inputs_scale = 10.0 / std::numeric_limits::max(); const float output_scale = 1.0; diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 63256f55310512..502ecf0ee6426e 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -28,7 +28,8 @@ namespace optimized_ops { inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, const RuntimeShape& rhs_shape, const float* rhs_data, const RuntimeShape& output_shape, float* output_data, - CpuBackendContext* context) { + CpuBackendContext* context, + bool transpose_lhs = false) { using ::tflite::cpu_backend_gemm::Gemm; using ::tflite::cpu_backend_gemm::GemmParams; using ::tflite::cpu_backend_gemm::MatrixParams; @@ -78,7 +79,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, const int accum_depth = extended_lhs_shape.Dims(4); MatrixParams lhs_params; - lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + if (transpose_lhs) { + 
lhs_params.order = cpu_backend_gemm::Order::kColMajor; + } else { + lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + } lhs_params.rows = lhs_rows; lhs_params.cols = accum_depth; @@ -221,7 +226,8 @@ inline void BatchMatMul(const FullyConnectedParams& params, const RuntimeShape& lhs_shape, const int8_t* lhs_data, const RuntimeShape& rhs_shape, const int8_t* rhs_data, const RuntimeShape& output_shape, int8_t* output_data, - CpuBackendContext* context) { + CpuBackendContext* context, + bool transpose_lhs = false) { using ::tflite::cpu_backend_gemm::Gemm; using ::tflite::cpu_backend_gemm::GemmParams; using ::tflite::cpu_backend_gemm::MatrixParams; @@ -281,7 +287,11 @@ inline void BatchMatMul(const FullyConnectedParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); MatrixParams lhs_params; - lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + if (transpose_lhs) { + lhs_params.order = cpu_backend_gemm::Order::kColMajor; + } else { + lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + } lhs_params.rows = lhs_rows; lhs_params.cols = accum_depth; lhs_params.zero_point = -filter_offset; @@ -327,7 +337,8 @@ inline void BatchMatMul(const FullyConnectedParams& params, const RuntimeShape& lhs_shape, const int8_t* lhs_data, const RuntimeShape& rhs_shape, const int8_t* rhs_data, const RuntimeShape& output_shape, int32_t* output_data, - CpuBackendContext* context) { + CpuBackendContext* context, + bool transpose_lhs = false) { using ::tflite::cpu_backend_gemm::Gemm; using ::tflite::cpu_backend_gemm::GemmParams; using ::tflite::cpu_backend_gemm::MatrixParams; @@ -387,7 +398,11 @@ inline void BatchMatMul(const FullyConnectedParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); MatrixParams lhs_params; - lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + if (transpose_lhs) { + lhs_params.order = cpu_backend_gemm::Order::kColMajor; + } else { + lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + } lhs_params.rows = lhs_rows; lhs_params.cols = accum_depth; lhs_params.zero_point = -weights_offset; From 8023633ce4caf48859c863e005edc1526089a191 Mon Sep 17 00:00:00 2001 From: Daniel Lang Date: Thu, 27 Jul 2023 11:33:59 +0000 Subject: [PATCH 220/410] [FTLite] Fix pthreadpool CMake integration With the changes introduced in 9c3e858 it is no longer possible to use a prebuilt version of pthreadpool. The only options are download it or specify a folder where it was downloaded. Add SYSTEM_PTHREADPOOL as a new option that triggers find_library() and fails if the library can't be found. --- tensorflow/lite/CMakeLists.txt | 8 +++++--- tensorflow/lite/g3doc/guide/build_cmake.md | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index bc97bac8a1b102..89ec80ab390807 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -158,8 +158,10 @@ find_package(ml_dtypes REQUIRED) find_package(ruy REQUIRED) # Download necessary dependencies. # Download pthreadpool source package if it doesn't exist. 
-if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) - message(STATUS "Downloading pthreadpool to ${CMAKE_BINARY_DIR}/pthreadpool-source (define PTHREADPOOL_SOURCE_DIR to avoid it)") +if(SYSTEM_PTHREADPOOL) + find_library(PTHREADPOOL_LIB pthreadpool REQUIRED) +elseif(NOT DEFINED PTHREADPOOL_SOURCE_DIR) + message(STATUS "Downloading pthreadpool to ${CMAKE_BINARY_DIR}/pthreadpool-source (define SYSTEM_PTHREADPOOL or PTHREADPOOL_SOURCE_DIR to avoid it)") configure_file(cmake/DownloadPThreadPool.cmake "${CMAKE_BINARY_DIR}/pthreadpool-download/CMakeLists.txt") execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/pthreadpool-download") @@ -168,7 +170,7 @@ if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) set(PTHREADPOOL_SOURCE_DIR "${CMAKE_BINARY_DIR}/pthreadpool-source" CACHE STRING "pthreadpool source directory") endif() # Configure pthreadpool -if(NOT TARGET pthreadpool) +if(NOT SYSTEM_PTHREADPOOL AND NOT TARGET pthreadpool) set(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") set(PTHREADPOOL_ALLOW_DEPRECATED_API OFF CACHE BOOL "") diff --git a/tensorflow/lite/g3doc/guide/build_cmake.md b/tensorflow/lite/g3doc/guide/build_cmake.md index 556c763ac0bb55..adcb1320e82374 100644 --- a/tensorflow/lite/g3doc/guide/build_cmake.md +++ b/tensorflow/lite/g3doc/guide/build_cmake.md @@ -82,6 +82,7 @@ variables to point to your library installations. cmake ../tensorflow_src/tensorflow/lite -DTFLITE_ENABLE_INSTALL=ON \ -DCMAKE_FIND_PACKAGE_PREFER_CONFIG=ON \ -DSYSTEM_FARMHASH=ON \ + -DSYSTEM_PTHREADPOOL=ON \ -Dabsl_DIR=/lib/cmake/absl \ -DEigen3_DIR=/share/eigen3/cmake \ -DFlatBuffers_DIR=/lib/cmake/flatbuffers \ From cdebceac5c2711acafa70b56a09105f02543948e Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Thu, 27 Jul 2023 04:39:41 -0700 Subject: [PATCH 221/410] In Subgraph::AddTensors, return an error if `tensors_to_add` parameter is negative. PiperOrigin-RevId: 551490259 --- tensorflow/lite/core/subgraph.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 62c148a485ceb9..e8a5da1396a4ec 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1782,6 +1782,7 @@ TfLiteStatus Subgraph::AddTensors(int tensors_to_add, int* first_new_tensor_index) { const size_t base_index = tensors_.size(); if (first_new_tensor_index) *first_new_tensor_index = base_index; + if (tensors_to_add < 0) return kTfLiteError; tensors_.resize(tensors_.size() + tensors_to_add); for (size_t i = base_index; i < tensors_.size(); i++) { memset(&tensors_[i], 0, sizeof(tensors_[i])); From 66f54179d8723938ed7c16168337fd23d3bbe6a4 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 27 Jul 2023 04:40:12 -0700 Subject: [PATCH 222/410] [XLA:GPU] Simplify operand index computations (NFC). 
PiperOrigin-RevId: 551490359 --- .../xla/service/gpu/priority_fusion.cc | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/priority_fusion.cc b/tensorflow/compiler/xla/service/gpu/priority_fusion.cc index 2e258ce5f6cc9a..088ef6d2c3691d 100644 --- a/tensorflow/compiler/xla/service/gpu/priority_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/priority_fusion.cc @@ -116,12 +116,13 @@ class GpuPriorityFusionQueue : public FusionQueue { } auto next_consumer = current_consumers_.back(); + int64_t producer_operand_index = + next_consumer->operand_index(current_producer_); current_consumers_.pop_back(); VLOG(5) << "next: " << next_consumer->name() << "(" << next_consumer << ") + " << current_producer_->name() << "(" << current_producer_ << ")"; - auto indices = next_consumer->OperandIndices(current_producer_); - return {next_consumer, {indices.begin(), indices.end()}}; + return {next_consumer, {producer_operand_index}}; } // Calculates the compute cost and free computation of the new fusion in the @@ -248,18 +249,11 @@ class GpuPriorityFusionQueue : public FusionQueue { } std::vector GetFusibleUsers(HloInstruction* producer) const { - auto is_fusible = [&](HloInstruction* user) { - for (int64_t i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) == producer && can_fuse_(user, i)) { - return true; - } - } - return false; - }; std::vector fusible_users; - std::vector prod_users(producer->users()); - for (auto user : prod_users) { - if (is_fusible(user)) { + for (auto user : producer->users()) { + int64_t operand_index = user->operand_index(producer); + + if (can_fuse_(user, operand_index)) { fusible_users.push_back(user); } } From 3e3e3cabf17f358e10d28d1387dbd6b1376753d1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 05:31:20 -0700 Subject: [PATCH 223/410] Integrate LLVM at llvm/llvm-project@4706251a3186 Updates LLVM usage to match [4706251a3186](https://github.com/llvm/llvm-project/commit/4706251a3186) PiperOrigin-RevId: 551500429 --- third_party/llvm/generated.patch | 23 ++++++++++++----------- third_party/llvm/workspace.bzl | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 442a292926da98..2bea0dc6ac565b 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,12 +1,13 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel ---- a/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel -+++ b/utils/bazel/llvm-project-overlay/libc/test/src/stdlib/BUILD.bazel -@@ -152,6 +152,7 @@ - deps = [ - "//libc:__support_cpp_limits", - "//libc:__support_cpp_type_traits", -+ "//libc:__support_macros_properties_architectures", - "//libc:errno.__internal__", - "//libc/test/UnitTest:LibcUnitTest", - ], +diff -ruN --strip-trailing-cr a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp +--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp ++++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp +@@ -1714,7 +1714,7 @@ + ");\n", + op.getCppClassName()); + } else { +- body << " result.addAttribute(\"odsResultSegmentSizes\", " ++ body << " result.addAttribute(\"result_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c7b951aa0465f0..2644759bfc40b1 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "365d6eb1f7d86cf28dc7d4995c3949e9d8bead58" - LLVM_SHA256 = "12156f9a68392d8248c1239a691719cb8fd25dfcc0265acbc3b054a347a58ca8" + LLVM_COMMIT = "4706251a3186c34da0ee8fd894f7e6b095da8fdc" + LLVM_SHA256 = "01cdfda6790f2b0d897423c6ba8147af6359f261c5aed62e157b9ecdf2ad591e" tf_http_archive( name = name, From 4ca12e7755987d6590722424e21bf1931a7cc13d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 06:25:28 -0700 Subject: [PATCH 224/410] This is an internal change. 
PiperOrigin-RevId: 551511055 --- tensorflow/lite/kernels/internal/BUILD | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 8db73e4a9bde22..3082f803a7db1e 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -429,6 +429,14 @@ cc_library( "optimized/4bit/neon_fully_connected_aarch64_sdot.cc", "optimized/4bit/neon_fully_connected_impl.h", ], + # copybara:uncomment_begin(google-only) + # "//tools/cc_target_os:toyota-sa8155-aarch64-oe-linux": [ + # "optimized/4bit/neon_fully_connected.cc", + # "optimized/4bit/neon_fully_connected_aarch64_nosdot.cc", + # "optimized/4bit/neon_fully_connected_aarch64_sdot.cc", + # "optimized/4bit/neon_fully_connected_impl.h", + # ], + # copybara:uncomment_end "//tensorflow:android_arm": [ "optimized/4bit/neon_fully_connected.cc", "optimized/4bit/neon_fully_connected_arm32.cc", @@ -451,6 +459,11 @@ cc_library( "//tensorflow:android_arm64": [ "optimized/4bit/neon_fully_connected.h", ], + # copybara:uncomment_begin(google-only) + # "//tools/cc_target_os:toyota-sa8155-aarch64-oe-linux": [ + # "optimized/4bit/neon_fully_connected.h", + # ], + # copybara:uncomment_end "//tensorflow:android_arm": [ "optimized/4bit/neon_fully_connected.h", ], @@ -461,6 +474,11 @@ cc_library( "//tensorflow:android_arm64": [ "-march=armv8.2-a+dotprod", ], + # copybara:uncomment_begin(google-only) + # "//tools/cc_target_os:toyota-sa8155-aarch64-oe-linux": [ + # "-march=armv8.2-a+dotprod", + # ], + # copybara:uncomment_end "//tensorflow:android_arm": [], "//conditions:default": [], }) + NEON_FLAGS_IF_APPLICABLE, @@ -469,6 +487,11 @@ cc_library( "//tensorflow:android_arm64": [ "FC_4BIT_NEON", ], + # copybara:uncomment_begin(google-only) + # "//tools/cc_target_os:toyota-sa8155-aarch64-oe-linux": [ + # "FC_4BIT_NEON", + # ], + # copybara:uncomment_end "//tensorflow:android_arm": [ "FC_4BIT_NEON", ], From f5946e3be0542c4e51ffcdef2d2e8a21119a30c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 06:32:54 -0700 Subject: [PATCH 225/410] Remove a redundant `alwayslink` declaration for the targets ... in lite/c/BUILD and lite/core/c/BUILD that has no `srcs`. `alwayslink` declaration doesn't have any effect since it only applies to `srcs`, not to `deps`. Also delete redundant `private_c_api` alias. All these changes don't affect the target outputs, only the build graph. PiperOrigin-RevId: 551512395 --- tensorflow/lite/c/BUILD | 7 ++----- tensorflow/lite/core/c/BUILD | 14 -------------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index bc00f79e36c19c..025a804cbee3c3 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -61,7 +61,6 @@ cc_library_with_tflite_with_c_headers_test( "//tensorflow/lite/core/async/c:types", "//tensorflow/lite/core/c:c_api", ], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) tflite_cc_library_with_c_headers_test( @@ -90,7 +89,6 @@ cc_library_with_tflite_with_c_headers_test( ":c_api_types", "//tensorflow/lite/core/c:c_api_without_op_resolver", ], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) # TODO(b/248514738): Deprecate this target in favour @@ -126,7 +124,6 @@ cc_library_with_tflite_with_c_headers_test( "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels:kernel_util", ], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. 
) # Same as ":c_api_experimental", but without linking in the default CreateOpResolver implementation. @@ -139,10 +136,10 @@ cc_library_with_tflite_with_c_headers_test( copts = tflite_copts() + tflite_copts_warnings(), tags = ["allow_undefined_symbols"], # For tflite::CreateOpResolver(). deps = ["//tensorflow/lite/core/c:c_api_experimental_without_op_resolver"], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) -# Same as ":c_api_experimental", but without linking in the default CreateOpResolver implementation. +# Same as ":c_api_experimental", but without linking in the default CreateOpResolver implementation, +# and without depending on targets that use alwayslink=1. cc_library_with_tflite_with_c_headers_test( name = "c_api_experimental_without_op_resolver_without_alwayslink", hdrs = [ diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 28fd64d1021052..96a67a6b5b1f48 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -67,19 +67,6 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite/c:common", "//tensorflow/lite/core/async/c:types", ], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. -) - -# This is a private target, its visibility is set to public only to be -# used by "tflite_custom_c_library". -# Do not use this target directly and don't consider it as a part of the public API. -alias( - name = "private_c_api", - actual = ":c_api", - tags = ["avoid_dep"], - visibility = [ - "//visibility:public", - ], ) tflite_cc_library_with_c_headers_test( @@ -328,7 +315,6 @@ tflite_cc_library_with_c_headers_test( "//tensorflow/lite/profiling/telemetry:profiler", "//tensorflow/lite/profiling/telemetry/c:profiler", ], - alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) # Same as ":c_api_experimental", but without linking in the default CreateOpResolver implementation. From 8d71992cb71060369d5280698c07482f0daad57a Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Thu, 27 Jul 2023 06:35:00 -0700 Subject: [PATCH 226/410] [XLA:GPU][NFC] Cleanup Triton GEMM rewriter. Extract commonly used types, rename a field, make a constructor private, fix a comment. PiperOrigin-RevId: 551512811 --- .../xla/service/gpu/gemm_rewriter_triton.cc | 74 +++++++++---------- 1 file changed, 36 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc index 82a928071cf208..2dfdd5e35e5418 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter_triton.cc @@ -156,6 +156,21 @@ FusionDecision RequireTritonFusibleConvert(const HloInstruction* input, // Used to calculate cumulative index transformations done by non-elementwise // instructions between source and target. class DimensionOrder { + // Dimension order constructed for the output shape of `hlo`. + // `hlo` is currently supposed to be either an operand or the output of dot(); + // properties describing the dimensions are stored for later analysis. 
+ explicit DimensionOrder( + const HloInstruction* hlo, const int64_t splittable_dimension_index = -1, + const int64_t splittable_dimension_supported_major_size = 0) + : splittable_dimension_index_(splittable_dimension_index), + splittable_dimension_supported_major_part_size_( + splittable_dimension_supported_major_size) { + dim_order_.reserve(hlo->shape().rank()); + for (const int64_t i : hlo->shape().layout().minor_to_major()) { + dim_order_.push_back({i, 0, hlo->shape().dimensions(i)}); + } + } + public: // Description of one dimension of HLO shape. struct DimDescription { @@ -172,24 +187,9 @@ class DimensionOrder { }; // Sequence describing all dimensions of HLO's output shape // in layout minor-to-major (physical) order. - using DimOrderVector = std::vector; + using RawDimOrder = std::vector; DimensionOrder(const DimensionOrder&) = default; - // Dimension order constructed for the output shape of `hlo`. - // `hlo` is currently supposed to be an operand of dot(); - // dimension indices describing the operand - // are stored along with the dimension order for later analysis. - explicit DimensionOrder( - const HloInstruction* hlo, const int64_t splittable_dimension_index = -1, - const int64_t splittable_dimension_supported_major_size = 0) - : splittable_dimension_index_(splittable_dimension_index), - splittable_dimension_supported_major_part_size_( - splittable_dimension_supported_major_size) { - dim_order_.reserve(hlo->shape().rank()); - for (const int64_t i : hlo->shape().layout().minor_to_major()) { - dim_order_.push_back({i, 0, hlo->shape().dimensions(i)}); - } - } // Create dimension order describing a dot operand according to // the currently supported configurations. @@ -236,7 +236,7 @@ class DimensionOrder { } // Get the raw data of the dimension order. - const DimOrderVector& GetDimOrderVector() const { return dim_order_; } + const RawDimOrder& GetRawDimOrder() const { return dim_order_; } // Index of dot dimension that can be split. // Currently typically LHS non-contracting one. @@ -267,17 +267,18 @@ class DimensionOrder { FusionDecision HandleCopyOrTransposeOrBroadcast(const HloInstruction*, TransformDirection); - DimOrderVector dim_order_; + RawDimOrder dim_order_; const int64_t splittable_dimension_index_; const int64_t splittable_dimension_supported_major_part_size_; }; using DimIterationSpec = TensorIterationSpec::DimIterationSpec; +using RawDimOrder = DimensionOrder::RawDimOrder; +using DimOrderMap = absl::flat_hash_map; TensorIterationSpec DimensionOrderToTensorIterationSpec( const DimensionOrder& order) { - const DimensionOrder::DimOrderVector& dim_order_vector = - order.GetDimOrderVector(); + const RawDimOrder& dim_order_vector = order.GetRawDimOrder(); TensorIterationSpec tensor_spec; int64_t accumulated_stride = 1; for (int dim_order_index = 0; dim_order_index < dim_order_vector.size(); @@ -363,7 +364,7 @@ FusionDecision DimensionOrder::HandleBitcast(const HloInstruction* hlo, const Shape& target_shape = (direction == TransformDirection::kOutputToInput) ? hlo->operand(0)->shape() : hlo->shape(); - DimOrderVector target_dim_order; + RawDimOrder target_dim_order; target_dim_order.reserve(dim_order_.size()); // Size of not yet assigned part of current target dimension. int64_t target_remaining_size = 1; @@ -455,13 +456,13 @@ FusionDecision DimensionOrder::HandleCopyOrTransposeOrBroadcast( (direction == TransformDirection::kOutputToInput) ? hlo : hlo->operand(0); const HloInstruction* dst = (direction == TransformDirection::kOutputToInput) ? 
hlo->operand(0) : hlo; - std::vector src_physical; + std::vector src_physical; src_physical.reserve(src->shape().rank()); auto dim_order_it = dim_order_.cbegin(); for (int64_t dim_index : src->shape().layout().minor_to_major()) { const int64_t dim_size = src->shape().dimensions(dim_index); int64_t subdim_size_accumulator = 1; - DimOrderVector subdim_group; + RawDimOrder subdim_group; do { subdim_size_accumulator *= dim_order_it->size; subdim_group.push_back(*dim_order_it); @@ -471,13 +472,13 @@ FusionDecision DimensionOrder::HandleCopyOrTransposeOrBroadcast( src_physical.push_back(subdim_group); } // Source physical -> source logical. - std::vector src_logical; + std::vector src_logical; src_logical.resize(src_physical.size()); for (int i = 0; i < src_physical.size(); ++i) { src_logical[src->shape().layout().minor_to_major(i)] = src_physical[i]; } // Source logical -> destination logical. - std::vector dst_logical; + std::vector dst_logical; if (hlo->opcode() == HloOpcode::kTranspose) { const auto transpose = Cast(hlo); std::vector permutation(transpose->dimensions().cbegin(), @@ -520,8 +521,7 @@ FusionDecision RequireTritonGemmSupportedDimOrder(const DimensionOrder& order) { -1, -1, -1, -1}; std::array split_counters = { -1, -1, -1, -1}; - const DimensionOrder::DimOrderVector& dim_order_vector = - order.GetDimOrderVector(); + const RawDimOrder& dim_order_vector = order.GetRawDimOrder(); VLOG(8) << order.ToString(); for (int i = 0; i < dim_order_vector.size(); i++) { const auto [dim_number, subdim_number, size] = dim_order_vector[i]; @@ -740,15 +740,14 @@ int64_t NumAddedParameters(const HloInstruction& hlo) { void TryToFuseWithInputsRecursively( HloInstruction& root, // Dimension orders describing outputs of corresponding instructions. - absl::flat_hash_map& dim_orders, - const GpuVersion gpu_version, + DimOrderMap& dim_orders, const GpuVersion gpu_version, absl::flat_hash_map& old_to_new_mapping, std::vector& fusion_inputs, HloComputation::Builder& builder) { absl::flat_hash_set visited; std::stack to_fuse; - // Instructions at the edge 'to_fuse' that can either get fused too or + // Instructions at the edge of 'to_fuse' that can either get fused too or // become parameters of the fusion. Used to track the number of parameters // of the fusion. absl::flat_hash_set inputs; @@ -831,10 +830,9 @@ StatusOr FuseDot(HloInstruction& dot, // differently shaped tiles but may go through same HLO graph nodes. // Direct dot inputs have well defined dimension orders. - auto fuse_inputs = [&](int operand_number) - -> StatusOr> { + auto fuse_inputs = [&](int operand_number) -> StatusOr { const int operand_count_before = fusion_inputs.size(); - absl::flat_hash_map dim_orders; + DimOrderMap dim_orders; // Direct dot inputs have well defined dimension orders. dim_orders.insert({dot.operand(operand_number), DimensionOrder::FromDotOperand(dot, operand_number)}); @@ -880,7 +878,7 @@ StatusOr FuseDot(HloInstruction& dot, // Fusion at dot's output. // These describe _outputs_ of corresponding HLOs. - absl::flat_hash_map out_dim_orders; + DimOrderMap out_dim_orders; out_dim_orders.insert( {&dot, DimensionOrder::FromDotOutput(dot, /*split_k=*/1, lhs_nc_split_major_part)}); @@ -1222,7 +1220,7 @@ Status PropagateDimensionOrdersToParameters( absl::flat_hash_set visited; std::queue to_process; // Dimension orders describing outputs of corresponding instructions. 
- absl::flat_hash_map dim_orders; + DimOrderMap dim_orders; TF_RET_CHECK(RequireTritonGemmSupportedDimOrder(origin_dim_order)); dim_orders.insert({&origin, origin_dim_order}); visited.insert(&origin); @@ -1420,15 +1418,15 @@ Status DotFusionAnalysis::ExecuteImpl(const HloComputation* computation, for (const Scope scope : {Scope::LHS, Scope::RHS}) { const int operand_number = static_cast(scope); - const HloInstruction* operand = dot->operand(operand_number); TF_RETURN_IF_ERROR(PropagateDimensionOrdersToParameters( - *operand, DimensionOrder::FromDotOperand(*dot, operand_number, split_k), + *dot->operand(operand_number), + DimensionOrder::FromDotOperand(*dot, operand_number, split_k), parameters_[scope], iter_specs_[scope])); } int64_t lhs_nc_split_major_part_size = -1; if (!ScopeParameters(Scope::LHS).empty()) { - const TensorIterationSpec::DimIterationSpec* lhs_nc_iter_spec = + const DimIterationSpec* lhs_nc_iter_spec = IterSpec(Scope::LHS, *ScopeParameters(Scope::LHS).cbegin(), NonContractingDimensionIndex(*dot, 0)); if (lhs_nc_iter_spec->size() > 1) { From 5ea4397d06102c629a869a56ff8c057d57f974a1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 06:38:49 -0700 Subject: [PATCH 227/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/3bf6c17968a52aea580c5398bbcfc0cf0e069dc5. PiperOrigin-RevId: 551513580 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index fb499cb5f9dae7..1639568a63e79e 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "dcffa4094e13d56a30b2cdcdb709ce5d71b38953" - TFRT_SHA256 = "ce4a95e4fe258353e751af57e6970a186aab77f03f12944ce389667b33d6c5b1" + TFRT_COMMIT = "3bf6c17968a52aea580c5398bbcfc0cf0e069dc5" + TFRT_SHA256 = "25c973d5ea4cdf6a0762fb3dead5162339ba7fa00f67cc2681e55f6da6c796ab" tf_http_archive( name = "tf_runtime", From 48fba8db346f158dc1db1500b856b46a7b952cad Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Thu, 27 Jul 2023 06:42:03 -0700 Subject: [PATCH 228/410] Clean up dead launched flag to always enable op by op fallback. PiperOrigin-RevId: 551514303 --- .../mlir/tf2xla/api/v1/legalize_tf.cc | 47 ++++--------------- 1 file changed, 9 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc index cf21467ad7ff6d..a4bca9e195b672 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/legalize_tf.cc @@ -128,7 +128,6 @@ bool ShouldFallbackToGraphCompiler( } Status CompileFromMlirToXlaHlo( - bool enable_op_fallback, const std::variant& computation, const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, @@ -137,18 +136,11 @@ Status CompileFromMlirToXlaHlo( const std::vector& arg_shapes, std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes) { - if (enable_op_fallback) { LOG_FIRST_N(INFO, 1) << "Compiling MLIR computation to XLA HLO using MLIR tf2xla bridge in " "the op by op fallback mode. This is Phase 2 of the TF2XLA Bridge. 
" "Old (non-MLIR) bridge may be used in case of unsupported feature " "or compilation failure from the MLIR bridge (full fallback mode)."; - } else { - LOG_FIRST_N(INFO, 1) - << "Compiling MLIR computation to XLA HLO using MLIR tf2xla bridge " - "phase 2. Fallback to the old (non-MLIR) bridge is disabled. " - "Op-by-op fallback is also disabled."; - } mlir::DialectRegistry registry; mlir::RegisterAllTensorFlowDialects(registry); @@ -163,7 +155,7 @@ Status CompileFromMlirToXlaHlo( TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo( SerializeMlirModule(mlir_module.get()), arg_shapes, device_type, - use_tuple_args, enable_op_fallback, shape_determination_fns, + use_tuple_args, /*enable_op_fallback=*/true, shape_determination_fns, compilation_result, custom_legalization_passes, metadata.module_name())); // Compute how arguments are shared across different cores. @@ -197,47 +189,26 @@ tsl::StatusOr LegalizeMlirToHlo( // We could only end up here if the MLIR bridge was explicitly enabled or // if it was in the default/unspecified state and graph analysis in the first // phase has not identified unsupported features. - // Enabling op fallback also enables whole graph fallback if op by op - // fallback failed. - bool enable_op_fallback = true; - Status mlir_bridge_status = tsl::OkStatus(); { CompilationTimer timer; - std::string enabled_string = enable_op_fallback ? "enabled" : "disabled"; - const std::string kMlirBridgeFallback = - absl::StrCat("mlir_bridge_op_fallback_", enabled_string); + const std::string kMlirBridgeFallback = "mlir_bridge_op_fallback_enabled"; mlir_bridge_status = CompileFromMlirToXlaHlo( - enable_op_fallback, computation, metadata, device_type, - shape_determination_fns, use_tuple_args, compilation_result.get(), - custom_legalization_passes, arg_shapes, arg_core_mapping, - per_core_arg_shapes); + computation, metadata, device_type, shape_determination_fns, + use_tuple_args, compilation_result.get(), custom_legalization_passes, + arg_shapes, arg_core_mapping, per_core_arg_shapes); phase2_bridge_compilation_time->GetCell(kMlirBridgeFallback) ->Add(timer.ElapsedCyclesInMilliseconds()); } if (mlir_bridge_status.ok()) { - if (enable_op_fallback) { - VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " - "tf2xla bridge"; - mlir_second_phase_count->GetCell(kMlirWithFallbackModeSuccess) - ->IncrementBy(1); - } else { - mlir_second_phase_count->GetCell(kMlirModeSuccess)->IncrementBy(1); - } + VLOG(1) << "Successfully compiled MLIR computation to XLA HLO using MLIR " + "tf2xla bridge"; + mlir_second_phase_count->GetCell(kMlirWithFallbackModeSuccess) + ->IncrementBy(1); return *compilation_result; - } else if (!enable_op_fallback) { - // Don't fallback to the old bridge if op-by-op fallback isn't enabled. - mlir_second_phase_count->GetCell(kMlirModeFailure)->IncrementBy(1); - if (!mlir_bridge_status.ok()) { - tsl::error_logging::Log(kBridgeComponent, - "TFXLA_API_V1_BRIDGE_NO_FALLBACK", - mlir_bridge_status.ToString()) - .IgnoreError(); - } - return mlir_bridge_status; } else { tsl::error_logging::Log(kBridgeComponent, "TFXLA_API_V1_BRIDGE_WITH_FALLBACK_FAIL", From 47b65c8ef5d27403650a34cc505d61d4cba3972c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 07:35:46 -0700 Subject: [PATCH 229/410] Add type annotations to parameter_server_strategy_v2. 
PiperOrigin-RevId: 551525663 --- tensorflow/python/distribute/BUILD | 3 ++- tensorflow/python/distribute/parameter_server_strategy_v2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index dade431b0dcf8d..7458f777fd5a58 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -2503,7 +2503,7 @@ distribute_py_strict_test( ], ) -py_strict_library( +pytype_strict_library( name = "parameter_server_strategy_v2", srcs = ["parameter_server_strategy_v2.py"], srcs_version = "PY3", @@ -2520,6 +2520,7 @@ py_strict_library( ":sharded_variable", ":values", "//tensorflow/core:protos_all_py", + "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/distribute/coordinator:cluster_coordinator", "//tensorflow/python/eager:context", "//tensorflow/python/eager:remote", diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index 8bf08b2dd8efdc..14fbba9133b748 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -21,6 +21,7 @@ import os import threading +from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -424,7 +425,7 @@ def dataset_fn(): """ # pyformat: disable - def __init__(self, cluster_resolver, variable_partitioner=None): + def __init__(self, cluster_resolver: cluster_resolver_lib.ClusterResolver, variable_partitioner: sharded_variable.Partitioner = None): """Initializes the TF2 parameter server strategy. This initializes the `tf.distribute.experimental.ParameterServerStrategy` From 4765c7238ca966074ae6df94dee4682c02737730 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 27 Jul 2023 08:40:19 -0700 Subject: [PATCH 230/410] [xla] Do not depend on any iree targets by default PiperOrigin-RevId: 551541217 --- .../compiler/xla/mlir/backends/openxla/BUILD | 9 +++++++ .../mlir/backends/openxla/build_config.bzl | 13 ++++++++++ .../mlir/backends/openxla/conversion/BUILD | 24 +++++++++---------- .../mlir/backends/openxla/transforms/BUILD | 9 +++---- .../mlir/backends/openxla/transforms/passes.h | 2 ++ .../backends/openxla/transforms/tests/BUILD | 4 ++++ 6 files changed, 44 insertions(+), 17 deletions(-) create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD index 30ab0c23dd1e6b..e17129b1968409 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD @@ -7,6 +7,15 @@ package( licenses = ["notice"], ) +# Add `--define=xla_gpu_with_openxla_compiler=1` to build command to enable experimental +# OpenXLA/IREE backend compiler. 
+config_setting( + name = "with_openxla_compiler", + values = { + "define": "xla_gpu_with_openxla_compiler=1", + }, +) + # copybara:uncomment_begin(not supported in OSS build) # # build_test( diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl b/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl new file mode 100644 index 00000000000000..3ab222a6da6929 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl @@ -0,0 +1,13 @@ +"""Helpers for conditional OpenXLA compilation.""" + +def if_openxla(then, otherwise = []): + return select({ + "//tensorflow/compiler/xla/mlir/backends/openxla:with_openxla_compiler": then, + "//conditions:default": otherwise, + }) + +def if_not_openxla(then, otherwise = []): + return select({ + "//tensorflow/compiler/xla/mlir/backends/openxla:with_openxla_compiler": otherwise, + "//conditions:default": then, + }) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD index f2262c9c55428e..adc00ecc48d2d1 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla/mlir/backends/openxla:build_config.bzl", "if_openxla") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -10,7 +11,7 @@ package( # # cc_library( # name = "de_bufferization", -# hdrs = ["de_bufferization.h"], +# hdrs = if_openxla(["de_bufferization.h"]), # deps = [ # "@llvm-project//llvm:Support", # "@llvm-project//mlir:IR", @@ -20,14 +21,13 @@ package( # # cc_library( # name = "convert_compiled_ops", -# srcs = ["convert_compiled_ops.cc"], -# hdrs = ["convert_compiled_ops.h"], +# srcs = if_openxla(["convert_compiled_ops.cc"]), +# hdrs = if_openxla(["convert_compiled_ops.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] # # because IREE targets are not compatible with `non_prod` constraint. # compatible_with = [], # deps = [ # ":de_bufferization", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # "@llvm-project//llvm:Support", # "@llvm-project//mlir:ArithDialect", # "@llvm-project//mlir:BufferizationDialect", @@ -41,38 +41,36 @@ package( # "//tensorflow/compiler/xla/service/gpu:gpu_executable", # "//tensorflow/compiler/xla/service/gpu:launch_dimensions", # "//tensorflow/compiler/xla/service/gpu:thunk", -# ], +# ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) # # cc_library( # name = "convert_memref_ops", -# srcs = ["convert_memref_ops.cc"], -# hdrs = ["convert_memref_ops.h"], +# srcs = if_openxla(["convert_memref_ops.cc"]), +# hdrs = if_openxla(["convert_memref_ops.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] # # because IREE targets are not compatible with `non_prod` constraint. 
# compatible_with = [], # deps = [ # ":de_bufferization", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # "@llvm-project//mlir:ArithDialect", # "@llvm-project//mlir:FuncDialect", # "@llvm-project//mlir:IR", # "@llvm-project//mlir:MemRefDialect", # "@llvm-project//mlir:TensorDialect", # "@llvm-project//mlir:Transforms", -# ], +# ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) # # cc_library( # name = "convert_while_op", -# srcs = ["convert_while_op.cc"], -# hdrs = ["convert_while_op.h"], +# srcs = if_openxla(["convert_while_op.cc"]), +# hdrs = if_openxla(["convert_while_op.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] # # because IREE targets are not compatible with `non_prod` constraint. # compatible_with = [], # deps = [ # ":de_bufferization", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # "@llvm-project//llvm:Support", # "@llvm-project//mlir:ArithDialect", # "@llvm-project//mlir:BufferizationDialect", @@ -83,7 +81,7 @@ package( # "@llvm-project//mlir:TensorDialect", # "@llvm-project//mlir:Transforms", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", -# ], +# ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) # # copybara:uncomment_end diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD index c72c0f2ce1db5d..335fad7d0b6195 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow/compiler/xla/mlir/backends/openxla:build_config.bzl", "if_not_openxla", "if_openxla") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -26,17 +27,17 @@ package( # # cc_library( # name = "passes", -# srcs = [ +# srcs = if_openxla([ # "convert_to_openxla.cc", # "passes.cc", -# ], +# ]), # hdrs = ["passes.h"], # # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not # # compatible with `non_prod` constraint. 
# compatible_with = [], +# defines = if_not_openxla(["XLA_DISABLE_OPENXLA_COMPILER=1"]), # deps = [ # ":passes_inc_gen", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # "@llvm-project//llvm:Support", # "@llvm-project//mlir:ArithDialect", # "@llvm-project//mlir:FuncDialect", @@ -51,7 +52,7 @@ package( # "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_memref_ops", # "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_while_op", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", -# ], +# ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) # # copybara:uncomment_end_and_comment_begin diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h index 50ed90d121d407..723cccd9eb86bf 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h @@ -30,6 +30,8 @@ namespace xla::gpu { class ThunkSequence; inline void populateOpenXlaRuntimePasses(mlir::OpPassManager&, ThunkSequence*) { } + +inline void registerOpenXlaPases() {} } // namespace xla::gpu //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD index 0238b82c42ebc1..50a76e76551f73 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/BUILD @@ -10,6 +10,10 @@ package( # # glob_lit_tests( # data = [":test_utilities"], +# default_tags = [ +# "manual", +# "notap", +# ], # driver = "//tensorflow/compiler/xla:run_lit.sh", # test_file_exts = ["mlir"], # ) From 55368a589120f0e211d7b76a32eab75a164093ac Mon Sep 17 00:00:00 2001 From: Nicolas Perez Date: Thu, 27 Jul 2023 08:43:06 -0700 Subject: [PATCH 231/410] Implement XLA kernel for general Conv Op. 
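For illustration only (not part of this change), a minimal sketch of driving the general
op under XLA compilation, assuming the auto-generated `tf.raw_ops.Conv` endpoint for the
existing `Conv` op definition:

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def conv2d_via_general_conv(x, f):
      # Assumes the auto-generated raw_ops endpoint for the `Conv` op def.
      # N=2 spatial dims, channels-last layout, unit strides over all input dims.
      return tf.raw_ops.Conv(
          input=x, filter=f, strides=[1, 1, 1, 1], padding="SAME",
          data_format="CHANNELS_LAST", batch_dims=1, groups=1)

    x = tf.random.uniform([2, 8, 8, 3])    # batch + spatial + in_channels
    f = tf.random.uniform([3, 3, 3, 16])   # spatial filter + [in_channels, out_channels]
    y = conv2d_via_general_conv(x, f)      # expected shape: [2, 8, 8, 16]
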
PiperOrigin-RevId: 551541935 --- .../compiler/jit/mark_for_compilation_pass.cc | 1 + .../mlir/tensorflow/ir/tf_generated_ops.td | 39 ++ .../transforms/legalization_op_config.cc | 1 + .../transforms/legalization_op_config_test.cc | 2 +- tensorflow/compiler/tests/BUILD | 3 + tensorflow/compiler/tests/conv2d_test.py | 188 ++++-- tensorflow/compiler/tests/conv3d_test.py | 632 ++++++++++++++++++ tensorflow/compiler/tests/test_utils.py | 34 +- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../tf2xla/kernels/conv_op_helpers.cc | 29 + .../compiler/tf2xla/kernels/conv_op_helpers.h | 14 + .../compiler/tf2xla/kernels/conv_ops.cc | 80 +++ .../core/api_def/base_api/api_def_Conv.pbtxt | 4 +- 13 files changed, 967 insertions(+), 61 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 23391ac6fd9824..896eb840737d53 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2044,6 +2044,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "CheckNumerics", "Cholesky", "ControlTrigger", + "Conv", "Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index ab078e2ad31018..29d063213a3fef 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2696,6 +2696,45 @@ used for communication with all hosts.}]>:$network_configs TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_ConvOp : TF_Op<"Conv", [Pure]> { + let summary = [{ +Computes a N-D convolution given (N+1+batch_dims)-D `input` and (N+2)-D `filter` tensors. + }]; + + let description = [{ +General function for computing a N-D convolution. It is required that +`1 <= N <= 3`. + }]; + + let arguments = (ins + Arg, [{Tensor of type T and shape `batch_shape + spatial_shape + [in_channels]` in the +case that `channels_last_format = true` or shape +`batch_shape + [in_channels] + spatial_shape` if `channels_last_format = false`. +spatial_shape is N-dimensional with `N=2` or `N=3`. +Also note that `batch_shape` is dictated by the parameter `batch_dims` +and defaults to 1.}]>:$input, + Arg, [{An `(N+2)-D` Tensor with the same type as `input` and shape +`spatial_filter_shape + [in_channels, out_channels]`, where spatial_filter_shape +is N-dimensional with `N=2` or `N=3`. +}]>:$filter, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedOptionalAttr:$explicit_paddings, + DefaultValuedOptionalAttr, "\"CHANNELS_LAST\"">:$data_format, + DefaultValuedOptionalAttr:$dilations, + DefaultValuedOptionalAttr:$batch_dims, + DefaultValuedOptionalAttr:$groups + ); + + let results = (outs + Res, [{A (N+1+batch_dims)-D tensor. The dimension order is determined by the value of +`channels_last_format`, see below for details.}]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Conv2DOp : TF_Op<"Conv2D", [InferTensorType, Pure, TF_LayoutSensitiveInterface]> { let summary = [{ Computes a 2-D convolution given 4-D `input` and `filter` tensors. 
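Note on the new op's calling convention: the tests added below exercise `Conv` through the generated `gen_nn_ops.conv` wrapper. A minimal sketch of a 3-D, channels-last call is shown here; the shapes and values are illustrative only and are not taken from the change itself.

# Illustrative sketch only: mirrors how the new tests call the generated
# wrapper for the `Conv` op; shapes and values are arbitrary.
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_nn_ops

# Channels-last 3-D input: [batch, depth, height, width, in_channels].
x = constant_op.constant(
    np.arange(1 * 4 * 4 * 4 * 3, dtype=np.float32).reshape(1, 4, 4, 4, 3))
# Filter: spatial_filter_shape + [in_channels, out_channels].
w = constant_op.constant(np.ones([2, 2, 2, 3, 6], dtype=np.float32))

y = gen_nn_ops.conv(
    x,
    w,
    strides=[1, 1, 1, 1, 1],      # one entry per input dimension
    padding="VALID",
    data_format="CHANNELS_LAST",  # "CHANNELS_FIRST" expects channels before the spatial dims
    dilations=[1, 1, 1, 1, 1])
print(y.shape)  # (1, 3, 3, 3, 6)

With `data_format="CHANNELS_FIRST"` the same call expects `batch_shape + [in_channels] + spatial_shape` input, matching the op description above.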
diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index c5df672d2d9378..fa138cc9bca4bb 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -178,6 +178,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 79220db6087a14..a194a5b330a20e 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -130,7 +130,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 71); - EXPECT_EQ(tf2xla_fallback_count, 295); + EXPECT_EQ(tf2xla_fallback_count, 296); EXPECT_EQ(non_categorized_count, 419); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c5d58bcde0b053..a678c6fbea5881 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -609,6 +609,7 @@ tf_xla_py_strict_test( "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ + ":test_utils", ":xla_test", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", @@ -616,8 +617,10 @@ tf_xla_py_strict_test( "//tensorflow/python/ops:gradient_checker", "//tensorflow/python/ops:nn_grad", "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:nn_ops_gen", "//tensorflow/python/platform:test", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 0a0929bafffa77..13191a52e1f261 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -34,25 +34,35 @@ ("_data_format_NCHW", "NCHW"), ) +CONV_CONFIGS = ( + ("_Conv2D_data_format_NHWC", "NHWC", "Conv2D"), + ("_Conv2D_data_format_NCHW", "NCHW", "Conv2D"), + ("_Conv_data_format_NHWC", "NHWC", "Conv"), + ("_Conv_data_format_NCHW", "NCHW", "Conv"), +) + class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): - def _VerifyValues(self, - input_sizes=None, - filter_sizes=None, - strides=None, - dilations=None, - padding=None, - data_format_src="NHWC", - data_format_dst="NHWC", - expected=None): + def _VerifyValues( + self, + input_sizes=None, + filter_sizes=None, + strides=None, + dilations=None, + padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", + expected=None, + op_name="Conv2D", + ): """Tests that tf.nn.conv2d produces the expected value. Args: - input_sizes: Input tensor dimensions in - [batch, input_rows, input_cols, input_depth]. - filter_sizes: Filter tensor dimensions in - [kernel_rows, kernel_cols, input_depth, output_depth]. + input_sizes: Input tensor dimensions in [batch, input_rows, input_cols, + input_depth]. + filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, + input_depth, output_depth]. strides: Strides. dilations: RHS dilations. padding: Padding type. 
@@ -60,6 +70,7 @@ def _VerifyValues(self, data_format_dst: Data format verification will run and input is converted to. expected: Expected output. + op_name: Name of operation to test (Conv/Conv2D) """ total_size_1 = np.prod(input_sizes) @@ -87,19 +98,35 @@ def _VerifyValues(self, t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): - out = nn_ops.conv2d( - t1, - t2, - strides=strides, - padding=padding, - data_format=data_format_dst, - dilations=dilations) + if op_name == "Conv": + conv_format = ( + "CHANNELS_LAST" if data_format_dst == "NHWC" else "CHANNELS_FIRST" + ) + out = gen_nn_ops.conv( + t1, + t2, + strides=strides, + padding=padding, + data_format=conv_format, + dilations=dilations, + ) + elif op_name == "Conv2D": + out = nn_ops.conv2d( + t1, + t2, + strides=strides, + padding=padding, + data_format=data_format_dst, + dilations=dilations, + ) + else: + raise ValueError("Invalid op name: %s" % op_name) value = sess.run(out, {t1: x1, t2: x2}) self.assertAllClose(expected, value, 1e-3) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D1x1Filter(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D1x1Filter(self, data_format, op_name): expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 @@ -111,10 +138,12 @@ def testConv2D1x1Filter(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D2x2Filter(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D2x2Filter(self, data_format, op_name): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( @@ -124,10 +153,12 @@ def testConv2D2x2Filter(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D2x2Filter2x1Dilation(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D2x2Filter2x1Dilation(self, data_format, op_name): expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) self._VerifyValues( input_sizes=[1, 4, 4, 1], @@ -137,10 +168,12 @@ def testConv2D2x2Filter2x1Dilation(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D1x2Filter(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D1x2Filter(self, data_format, op_name): expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 @@ -152,10 +185,12 @@ def testConv2D1x2Filter(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D2x2FilterStride2(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D2x2FilterStride2(self, data_format, op_name): expected_output = np.reshape([2271.0, 2367.0, 
2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], @@ -164,10 +199,12 @@ def testConv2D2x2FilterStride2(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D2x2FilterStride2Same(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D2x2FilterStride2Same(self, data_format, op_name): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( @@ -177,10 +214,12 @@ def testConv2D2x2FilterStride2Same(self, data_format): padding="SAME", data_format_src="NHWC", data_format_dst=data_format, - expected=expected_output) + expected=expected_output, + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2DEmptyDilation(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2DEmptyDilation(self, data_format, op_name): self._VerifyValues( input_sizes=[0, 2, 3, 3], filter_sizes=[1, 1, 3, 3], @@ -189,10 +228,12 @@ def testConv2DEmptyDilation(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=np.zeros([0, 2, 3, 3])) + expected=np.zeros([0, 2, 3, 3]), + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D2x2FilterDilation(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D2x2FilterDilation(self, data_format, op_name): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], @@ -201,10 +242,12 @@ def testConv2D2x2FilterDilation(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) + expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3]), + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2D1x2FilterDilation(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2D1x2FilterDilation(self, data_format, op_name): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], @@ -213,11 +256,15 @@ def testConv2D1x2FilterDilation(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=np.array([[[[231, 252, 273], [384, 423, 462]], - [[690, 765, 840], [843, 936, 1029]]]])) + expected=np.array([[ + [[231, 252, 273], [384, 423, 462]], + [[690, 765, 840], [843, 936, 1029]], + ]]), + op_name=op_name, + ) - @parameterized.named_parameters(*DATA_FORMATS) - def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format): + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format, op_name): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -226,7 +273,46 @@ def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format): padding="VALID", data_format_src="NHWC", data_format_dst=data_format, - expected=np.reshape([108, 128], [1, 1, 1, 2])) + expected=np.reshape([108, 128], [1, 1, 1, 2]), + op_name=op_name, + ) + + def testConvExpandedBatch(self): + tensor_in_sizes_batch = [10, 2, 3, 3] + tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3] + batch_dims = 2 + filter_in_sizes = [1, 1, 3, 3] + filter_in = np.arange( + 1, np.prod(filter_in_sizes) + 1, dtype=np.float32 + ).reshape(filter_in_sizes) + x1 = np.arange( + 1, np.prod(tensor_in_sizes_batch) + 1, 
dtype=np.float32 + ).reshape(tensor_in_sizes_batch) + x2 = x1.reshape(tensor_in_sizes_expanded_batch) + + with self.session() as sess: + t1 = array_ops.placeholder(dtypes.bfloat16, shape=tensor_in_sizes_batch) + t2 = array_ops.placeholder( + dtypes.bfloat16, shape=tensor_in_sizes_expanded_batch + ) + filter_t = array_ops.placeholder(dtypes.bfloat16, shape=filter_in_sizes) + + out1 = gen_nn_ops.conv( + t1, filter_t, strides=[1, 1, 1, 1], padding="VALID" + ) + out2 = gen_nn_ops.conv( + t2, + filter_t, + strides=[1, 1, 1, 1], + padding="VALID", + batch_dims=batch_dims, + ) + value1 = sess.run(out1, {t1: x1, filter_t: filter_in}) + value2 = sess.run(out2, {t2: x2, filter_t: filter_in}) + + self.assertEqual(list(value1.shape), tensor_in_sizes_batch) + self.assertEqual(list(value2.shape), tensor_in_sizes_expanded_batch) + self.assertAllCloseAccordingToType(value1, value2.reshape(value1.shape)) class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 55d9593eee4906..9cfa0920f5d1c7 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -14,18 +14,650 @@ # ============================================================================== """Tests for 3D convolutions using the XLA JIT.""" +from absl.testing import parameterized import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import googletest +CONV_CONFIGS = ( + ("_Conv3D_data_format_NDHWC", "NDHWC", "Conv3D"), + ("_Conv3D_data_format_NCDHW", "NCDHW", "Conv3D"), + ("_Conv_data_format_NDHWC", "NDHWC", "Conv"), + ("_Conv_data_format_NCDHW", "NCDHW", "Conv"), +) + + +# Test outputs computed in prod (colab) by running nn.conv3d on a GPU device +# with its GPU (non-xla) kernel. +class Conv3DTest(xla_test.XLATestCase, parameterized.TestCase): + + def _VerifyValues( + self, + input_sizes=None, + filter_sizes=None, + strides=None, + dilations=None, + padding=None, + data_format_src="NDHWC", + data_format_dst="NDHWC", + expected=None, + op_name="Conv3D", + ): + """Tests that tf.nn.conv3d produces the expected value. + + Args: + input_sizes: Input tensor dimensions in [batch, input_rows, input_cols, + input_depth]. + filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, + input_depth, output_depth]. + strides: Strides. + dilations: RHS dilations. + padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. + expected: Expected output. + op_name: Name of operation to test (Conv/Conv2D) + """ + + total_size_1 = np.prod(input_sizes) + total_size_2 = np.prod(filter_sizes) + x1 = np.reshape( + [f * 1.0 / total_size_1 for f in range(1, total_size_1 + 1)], + input_sizes, + ) + x2 = np.reshape( + [f * 1.0 / total_size_2 for f in range(1, total_size_2 + 1)], + filter_sizes, + ) + strides = [1] + strides + [1] + if dilations is None: + dilations = [1, 1, 1] + dilations = [1] + dilations + [1] + + # Convert between data formats. 
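The data-format handling in this helper relies on the `test_utils` changes later in this patch; conceptually it is just an axis permutation between `NDHWC` and `NCDHW`. A small stand-alone sketch of that permutation follows; the helper name below is local to the sketch and not part of the change (the real helpers live in tensorflow/compiler/tests/test_utils.py).

# Stand-alone sketch of the NDHWC <-> NCDHW permutation the helpers perform.
import numpy as np

def _perm(src, dst):
  # e.g. _perm("NDHWC", "NCDHW") == [0, 4, 1, 2, 3]
  return [src.index(axis) for axis in dst]

x_ndhwc = np.arange(1 * 2 * 3 * 1 * 3, dtype=np.float32).reshape(1, 2, 3, 1, 3)
x_ncdhw = np.transpose(x_ndhwc, _perm("NDHWC", "NCDHW"))
print(x_ncdhw.shape)  # (1, 3, 2, 3, 1)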
+ expected = test_utils.ConvertBetweenDataFormats( + expected, data_format_src, data_format_dst + ) + x1 = test_utils.ConvertBetweenDataFormats( + x1, data_format_src, data_format_dst + ) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst + ) + strides = test_utils.PermuteDimsBetweenDataFormats( + strides, data_format_src, data_format_dst + ) + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst + ) + + with self.session() as sess: + t1 = array_ops.placeholder(dtypes.bfloat16, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.bfloat16, shape=filter_sizes) + with self.test_scope(): + if op_name == "Conv": + conv_format = ( + "CHANNELS_LAST" + if data_format_dst == "NDHWC" + else "CHANNELS_FIRST" + ) + out = gen_nn_ops.conv( + t1, + t2, + strides=strides, + padding=padding, + data_format=conv_format, + dilations=dilations, + ) + elif op_name == "Conv3D": + out = nn_ops.conv3d( + t1, + t2, + strides=strides, + padding=padding, + data_format=data_format_dst, + dilations=dilations, + ) + else: + raise ValueError("Invalid op name: %s" % op_name) + + value = sess.run(out, {t1: x1, t2: x2}) + self.assertAllCloseAccordingToType(expected, value) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D1x1x1Filter(self, data_format, op_name): + expected_output = np.reshape( + [ + 0.18518518518518517, + 0.2222222222222222, + 0.25925925925925924, + 0.4074074074074074, + 0.5, + 0.5925925925925926, + 0.6296296296296297, + 0.7777777777777777, + 0.9259259259259259, + 0.8518518518518519, + 1.0555555555555556, + 1.259259259259259, + 1.074074074074074, + 1.3333333333333333, + 1.5925925925925926, + 1.2962962962962963, + 1.6111111111111112, + 1.9259259259259258, + ], + [1, 2, 3, 1, 3], + ) + + # These are equivalent to the Conv2D1x1 case. + self._VerifyValues( + input_sizes=[1, 2, 3, 1, 3], + filter_sizes=[1, 1, 1, 3, 3], + strides=[1, 1, 1], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + self._VerifyValues( + input_sizes=[1, 2, 1, 3, 3], + filter_sizes=[1, 1, 1, 3, 3], + strides=[1, 1, 1], + padding="VALID", + expected=np.reshape(expected_output, [1, 2, 1, 3, 3]), + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + self._VerifyValues( + input_sizes=[1, 1, 2, 3, 3], + filter_sizes=[1, 1, 1, 3, 3], + strides=[1, 1, 1], + padding="VALID", + expected=np.reshape(expected_output, [1, 1, 2, 3, 3]), + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D1x1x1Filter2x1x1Dilation(self, data_format, op_name): + expected_output = np.reshape( + [ + 0.05555555555555555, + 0.1111111111111111, + 0.16666666666666666, + 0.2222222222222222, + 0.2777777777777778, + 0.3333333333333333, + 0.3888888888888889, + 0.4444444444444444, + 0.5, + 0.5555555555555556, + 0.6111111111111112, + 0.6666666666666666, + 0.7222222222222222, + 0.7777777777777778, + 0.8333333333333334, + 0.8888888888888888, + 0.9444444444444444, + 1.0, + ], + [1, 3, 6, 1, 1], + ) + + self._VerifyValues( + input_sizes=[1, 3, 6, 1, 1], + filter_sizes=[1, 1, 1, 1, 1], + strides=[1, 1, 1], + padding="VALID", + dilations=[2, 1, 1], + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + # Expected values computed using scipy's correlate function. 
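For reference, one way to reproduce such hard-coded expectations for the stride-1, VALID-padding cases is scipy's N-dimensional correlate (TF convolution is cross-correlation, so no kernel flip is needed). This is a sketch with a locally named helper, not the script actually used to generate the values.

# Sketch of a NumPy/SciPy reference for the stride-1, VALID-padding cases.
import numpy as np
from scipy import signal

def conv3d_valid_reference(x, w):
  """x: [N, D, H, W, Cin], w: [kD, kH, kW, Cin, Cout] -> [N, D', H', W', Cout]."""
  n, cin, cout = x.shape[0], w.shape[3], w.shape[4]
  out = []
  for b in range(n):
    channels = []
    for co in range(cout):
      # Correlate each input channel against its 3-D kernel slice and sum.
      acc = sum(
          signal.correlate(x[b, ..., ci], w[..., ci, co], mode="valid")
          for ci in range(cin))
      channels.append(acc)
    out.append(np.stack(channels, axis=-1))
  return np.stack(out, axis=0)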
+ @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D2x2x2Filter(self, data_format, op_name): + expected_output = np.reshape( + [ + 3.7719907407407405, + 3.850694444444445, + 3.929398148148149, + 4.265046296296295, + 4.357638888888888, + 4.450231481481481, + 6.730324074074074, + 6.892361111111109, + 7.054398148148148, + 7.223379629629629, + 7.399305555555557, + 7.575231481481481, + 9.688657407407408, + 9.934027777777779, + 10.17939814814815, + 10.181712962962962, + 10.440972222222221, + 10.700231481481481, + ], + [1, 3, 1, 2, 3], + ) + # expected_shape = [1, 3, 1, 2, 5] + self._VerifyValues( + input_sizes=[1, 4, 2, 3, 3], # b, z, y, x, fin + filter_sizes=[2, 2, 2, 3, 3], # z, y, x, fin, fout + strides=[1, 1, 1], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D2x2x2Filter1x2x1Dilation(self, data_format, op_name): + expected_output = np.reshape( + [ + 1.1388888888888888, + 1.2013888888888888, + 1.3263888888888888, + 1.3888888888888888, + 1.5138888888888888, + 1.5763888888888888, + 1.701388888888889, + 1.763888888888889, + 2.263888888888889, + 2.3263888888888893, + 2.451388888888889, + 2.513888888888889, + 2.6388888888888893, + 2.701388888888889, + 2.826388888888889, + 2.888888888888889, + 3.388888888888889, + 3.451388888888889, + 3.576388888888889, + 3.6388888888888884, + 3.7638888888888893, + 3.8263888888888893, + 3.9513888888888893, + 4.013888888888889, + ], + [1, 3, 4, 2, 1], + ) + + self._VerifyValues( + input_sizes=[1, 4, 6, 3, 1], + filter_sizes=[2, 2, 2, 1, 1], + strides=[1, 1, 1], + padding="VALID", + dilations=[1, 2, 1], + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3DStrides(self, data_format, op_name): + expected_output = np.reshape( + [ + 0.06071428571428571, + 0.08988095238095238, + 0.10238095238095238, + 0.11488095238095238, + 0.12738095238095237, + 0.13988095238095238, + 0.08452380952380953, + 0.26071428571428573, + 0.35238095238095235, + 0.36488095238095236, + 0.3773809523809524, + 0.3898809523809524, + 0.4023809523809524, + 0.23452380952380952, + 0.46071428571428574, + 0.6148809523809524, + 0.6273809523809524, + 0.6398809523809523, + 0.6523809523809524, + 0.6648809523809525, + 0.3845238095238095, + 1.1273809523809524, + 1.4898809523809524, + 1.5023809523809524, + 1.5148809523809523, + 1.5273809523809523, + 1.5398809523809525, + 0.8845238095238095, + 1.3273809523809526, + 1.7523809523809522, + 1.764880952380952, + 1.7773809523809523, + 1.7898809523809525, + 1.8023809523809526, + 1.0345238095238096, + 1.5273809523809525, + 2.0148809523809526, + 2.0273809523809523, + 2.0398809523809525, + 2.052380952380952, + 2.0648809523809524, + 1.1845238095238095, + 2.1940476190476192, + 2.8898809523809526, + 2.9023809523809527, + 2.9148809523809525, + 2.9273809523809526, + 2.9398809523809524, + 1.6845238095238095, + 2.394047619047619, + 3.1523809523809523, + 3.1648809523809525, + 3.177380952380952, + 3.1898809523809524, + 3.2023809523809526, + 1.8345238095238097, + 2.594047619047619, + 3.4148809523809525, + 3.427380952380952, + 3.4398809523809524, + 3.4523809523809526, + 3.4648809523809523, + 1.9845238095238096, + ], + [1, 3, 3, 7, 1], + ) + self._VerifyValues( + input_sizes=[1, 5, 8, 7, 1], + filter_sizes=[1, 2, 3, 1, 1], + strides=[2, 3, 1], # different stride for each spatial dimension + padding="SAME", + 
expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D2x2x2FilterStride2(self, data_format, op_name): + expected_output = np.reshape( + [ + 3.7719907407407405, + 3.850694444444445, + 3.929398148148149, + 9.688657407407408, + 9.934027777777779, + 10.17939814814815, + ], + [1, 2, 1, 1, 3], + ) + self._VerifyValues( + input_sizes=[1, 4, 2, 3, 3], + filter_sizes=[2, 2, 2, 3, 3], + strides=[2, 2, 2], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3DStride3(self, data_format, op_name): + expected_output = np.reshape( + [ + 1.5114087301587302, + 1.5716765873015872, + 1.6319444444444446, + 1.5634920634920635, + 1.6267361111111112, + 1.6899801587301588, + 1.6155753968253967, + 1.681795634920635, + 1.748015873015873, + 1.9280753968253967, + 2.012152777777778, + 2.096230158730159, + 1.9801587301587302, + 2.067212301587302, + 2.154265873015873, + 2.0322420634920637, + 2.122271825396825, + 2.2123015873015874, + 4.428075396825396, + 4.65500992063492, + 4.881944444444444, + 4.480158730158729, + 4.710069444444444, + 4.939980158730158, + 4.532242063492063, + 4.7651289682539675, + 4.9980158730158735, + 4.844742063492064, + 5.095486111111112, + 5.346230158730158, + 4.896825396825397, + 5.150545634920635, + 5.4042658730158735, + 4.94890873015873, + 5.205605158730158, + 5.462301587301588, + ], + [1, 2, 2, 3, 3], + ) + self._VerifyValues( + input_sizes=[1, 6, 7, 8, 2], + filter_sizes=[3, 2, 1, 2, 3], + strides=[3, 3, 3], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testConv3D2x2x2FilterStride2Same(self, data_format, op_name): + expected_output = np.reshape( + [ + 3.7719907407407405, + 3.850694444444445, + 3.929398148148149, + 2.0162037037037037, + 2.0659722222222223, + 2.1157407407407405, + 9.688657407407408, + 9.934027777777779, + 10.17939814814815, + 4.599537037037037, + 4.732638888888889, + 4.8657407407407405, + ], + [1, 2, 1, 2, 3], + ) + self._VerifyValues( + input_sizes=[1, 4, 2, 3, 3], + filter_sizes=[2, 2, 2, 3, 3], + strides=[2, 2, 2], + padding="SAME", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testKernelSmallerThanStride(self, data_format, op_name): + expected_output = np.reshape( + [ + 0.037037037037037035, + 0.1111111111111111, + 0.25925925925925924, + 0.3333333333333333, + 0.7037037037037037, + 0.7777777777777778, + 0.9259259259259259, + 1.0, + ], + [1, 2, 2, 2, 1], + ) + + self._VerifyValues( + input_sizes=[1, 3, 3, 3, 1], + filter_sizes=[1, 1, 1, 1, 1], + strides=[2, 2, 2], + padding="SAME", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + self._VerifyValues( + input_sizes=[1, 3, 3, 3, 1], + filter_sizes=[1, 1, 1, 1, 1], + strides=[2, 2, 2], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + expected_output = np.reshape( + [ + 0.5408163265306123, + 0.5801749271137027, + 0.28061224489795916, + 0.8163265306122448, + 0.8556851311953353, + 0.4030612244897959, + 0.41873177842565595, + 0.43403790087463556, + 0.19642857142857142, + 
2.4693877551020407, + 2.5087463556851315, + 1.1377551020408163, + 2.7448979591836733, + 2.7842565597667637, + 1.260204081632653, + 1.168731778425656, + 1.1840379008746356, + 0.5178571428571429, + 1.0951166180758019, + 1.1060495626822158, + 0.4464285714285714, + 1.1716472303206997, + 1.1825801749271136, + 0.4770408163265306, + 0.3691690962099125, + 0.37244897959183676, + 0.125, + ], + [1, 3, 3, 3, 1], + ) + self._VerifyValues( + input_sizes=[1, 7, 7, 7, 1], + filter_sizes=[2, 2, 2, 1, 1], + strides=[3, 3, 3], + padding="SAME", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + expected_output = np.reshape( + [ + 0.5408163265306123, + 0.5801749271137027, + 0.8163265306122448, + 0.8556851311953353, + 2.4693877551020407, + 2.5087463556851315, + 2.7448979591836733, + 2.7842565597667637, + ], + [1, 2, 2, 2, 1], + ) + self._VerifyValues( + input_sizes=[1, 7, 7, 7, 1], + filter_sizes=[2, 2, 2, 1, 1], + strides=[3, 3, 3], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + @parameterized.named_parameters(*CONV_CONFIGS) + def testKernelSizeMatchesInputSize(self, data_format, op_name): + expected_output = np.reshape([1.5625, 1.875], [1, 1, 1, 1, 2]) + self._VerifyValues( + input_sizes=[1, 2, 1, 2, 1], + filter_sizes=[2, 1, 2, 1, 2], + strides=[1, 1, 1], + padding="VALID", + expected=expected_output, + data_format_src="NDHWC", + data_format_dst=data_format, + op_name=op_name, + ) + + def testConvExpandedBatch(self): + tensor_in_sizes_batch = [10, 2, 3, 1, 3] + tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 1, 3] + batch_dims = 2 + filter_in_sizes = [1, 1, 1, 3, 3] + filter_in = np.arange( + 1, np.prod(filter_in_sizes) + 1, dtype=np.float32 + ).reshape(filter_in_sizes) + x1 = np.arange( + 1, np.prod(tensor_in_sizes_batch) + 1, dtype=np.float32 + ).reshape(tensor_in_sizes_batch) + x2 = x1.reshape(tensor_in_sizes_expanded_batch) + + with self.session() as sess: + t1 = array_ops.placeholder(dtypes.bfloat16, shape=tensor_in_sizes_batch) + t2 = array_ops.placeholder( + dtypes.bfloat16, shape=tensor_in_sizes_expanded_batch + ) + filter_t = array_ops.placeholder(dtypes.bfloat16, shape=filter_in_sizes) + + out1 = gen_nn_ops.conv( + t1, filter_t, strides=[1, 1, 1, 1, 1], padding="VALID" + ) + out2 = gen_nn_ops.conv( + t2, + filter_t, + strides=[1, 1, 1, 1, 1], + padding="VALID", + batch_dims=batch_dims, + ) + value1 = sess.run(out1, {t1: x1, filter_t: filter_in}) + value2 = sess.run(out2, {t2: x2, filter_t: filter_in}) + + self.assertEqual(list(value1.shape), tensor_in_sizes_batch) + self.assertEqual(list(value2.shape), tensor_in_sizes_expanded_batch) + self.assertAllCloseAccordingToType(value1, value2.reshape(value1.shape)) + + # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py index 049bddc4aa7685..a506dfd07c77f0 100644 --- a/tensorflow/compiler/tests/test_utils.py +++ b/tensorflow/compiler/tests/test_utils.py @@ -19,17 +19,27 @@ def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): - """Converts 4D tensor between data formats.""" + """Converts 4D/5D tensor between data formats.""" - valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN", "NDHWC", "NCDHW"] + if len(data_format_src) != len(data_format_dst): + 
raise ValueError( + "data_format_src and data_format_dst must have the same dimension, got" + " %s and %s." % (len(data_format_src), len(data_format_dst)) + ) if data_format_src not in valid_data_formats: raise ValueError("data_format_src must be of %s, got %s." % (valid_data_formats, data_format_src)) if data_format_dst not in valid_data_formats: raise ValueError("data_format_dst must be of %s, got %s." % (valid_data_formats, data_format_dst)) - if len(x.shape) != 4: - raise ValueError("x must be 4D, got shape %s." % x.shape) + if len(x.shape) != 4 and len(x.shape) != 5: + raise ValueError("x must be 4D or 5D, got shape %s." % x.shape) + if len(x.shape) != len(data_format_src): + raise ValueError( + "x must be the same dimensions as data_format_src (%s), got shape %s." + % (len(data_format_src), x.shape) + ) if data_format_src == data_format_dst: return x @@ -42,15 +52,25 @@ def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): """Get new shape for converting between data formats.""" - valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN", "NDHWC", "NCDHW"] + if len(data_format_src) != len(data_format_dst): + raise ValueError( + "data_format_src and data_format_dst must have the same dimension, got" + " %s and %s." % (len(data_format_src), len(data_format_dst)) + ) if data_format_src not in valid_data_formats: raise ValueError("data_format_src must be of %s, got %s." % (valid_data_formats, data_format_src)) if data_format_dst not in valid_data_formats: raise ValueError("data_format_dst must be of %s, got %s." % (valid_data_formats, data_format_dst)) - if len(dims) != 4: - raise ValueError("dims must be of length 4, got %s." % dims) + if len(dims) != 4 and len(dims) != 5: + raise ValueError("dims must be of length 4 or 5, got %s." % dims) + if len(dims) != len(data_format_src): + raise ValueError( + "dims must be the same dimensions as data_format_src (%s), got %s." + % (len(data_format_src), dims) + ) if data_format_src == data_format_dst: return dims diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 1c60ba5874746b..2fac8f21e97a06 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -347,6 +347,7 @@ cc_library( "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/core/platform:statusor", "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ] + if_static(["//tensorflow/tsl/platform:tensor_float_32_utils"]), ) diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 242c022c892faf..8eaa39a1fcde12 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -206,6 +207,34 @@ StatusOr ConvOpAttrs::Create(int num_spatial_dims, bool depthwise, return attrs; } +StatusOr ConvNDOpAttrs::Create(OpKernelConstruction* ctx) { + ConvNDOpAttrs attrs; + TF_RETURN_IF_ERROR(ctx->GetAttr("groups", &attrs.groups)); + TF_RETURN_IF_ERROR(ctx->GetAttr("batch_dims", &attrs.batch_dims)); + if (attrs.batch_dims < 1) { + return absl::InvalidArgumentError("batch_dims must be non-negative."); + } + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + if (attrs.padding == EXPLICIT) { + TF_RETURN_IF_ERROR( + ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); + } + + string data_format_str; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format_str)); + if (!(data_format_str == "CHANNELS_LAST" || + data_format_str == "CHANNELS_FIRST")) { + return absl::InvalidArgumentError( + absl::StrCat("Unknown data format: ", data_format_str)); + } + attrs.data_format = + data_format_str == "CHANNELS_LAST" ? FORMAT_NHWC : FORMAT_NCHW; + + return attrs; +} + StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 70c579cde73cc5..bec541e47cced9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -56,6 +56,20 @@ struct ConvOpAttrs { TensorFormat data_format; }; +// Helper for the general Conv Op. +struct ConvNDOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static StatusOr Create(OpKernelConstruction* ctx); + + int groups; + int batch_dims; + std::vector dilations; + std::vector strides; + Padding padding; + std::vector explicit_paddings; + TensorFormat data_format; +}; + // Creates a new XLA forward or backward convolution with the given inputs and // attributes. StatusOr MakeXlaForwardConvOp(StringPiece type_string, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 0f1b53c8a56b25..02dbb8fe2c8f0f 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,6 +15,9 @@ limitations under the License. // XLA-specific Ops for 2D convolution. +#include +#include + #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -29,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -65,6 +69,82 @@ class ConvOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); }; +class ConvNDOp : public XlaOpKernel { + public: + explicit ConvNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + StatusOr attrs = ConvNDOpAttrs::Create(ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.value(); + } + + void Compile(XlaOpKernelContext* ctx) override { + // Need to know input rank ahead of time to determine type of convolution. + OP_REQUIRES_VALUE(xla::Shape input_shape, ctx, ctx->InputXlaShape(0)); + int num_spatial_dims = input_shape.rank() - 1 - attrs_.batch_dims; + OP_REQUIRES_OK(ctx, + CheckValidPadding(attrs_.padding, attrs_.explicit_paddings, + /*num_dims=*/num_spatial_dims + 2, + attrs_.data_format)); + + ConvOpAttrs forward_attrs; + forward_attrs.depthwise = false; + forward_attrs.num_spatial_dims = num_spatial_dims; + forward_attrs.dilations = attrs_.dilations.empty() + ? std::vector(num_spatial_dims + 2, 1) + : attrs_.dilations; + forward_attrs.strides = attrs_.strides; + forward_attrs.padding = attrs_.padding; + forward_attrs.explicit_paddings = attrs_.explicit_paddings; + forward_attrs.data_format = attrs_.data_format; + + xla::XlaOp input = ctx->Input(0); + xla::XlaOp filter = ctx->Input(1); + + if (attrs_.batch_dims == 0) { + // Expand dummy batch dimension. + xla::Shape expanded_input_shape(input_shape); + for (int i = 0; i < expanded_input_shape.rank() - 1; ++i) { + expanded_input_shape.set_dimensions(i + 1, input_shape.dimensions(i)); + } + expanded_input_shape.set_dimensions(0, 1); + input = xla::Reshape(input, expanded_input_shape.dimensions()); + } else if (attrs_.batch_dims > 1) { + // Flatten batch_dims. + std::vector to_collapse(attrs_.batch_dims); + for (int i = 0; i < attrs_.batch_dims; ++i) { + to_collapse[i] = i; + } + input = xla::Collapse(input, to_collapse); + } + + StatusOr forward = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), input, filter, forward_attrs); + OP_REQUIRES_OK(ctx, forward.status()); + + xla::XlaOp out = forward.value(); + auto* builder = out.builder(); + OP_REQUIRES_VALUE(xla::Shape out_shape, ctx, builder->GetShape(out)); + // Reshape output. 
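The batch_dims handling implemented here (collapse the extra leading dimensions, run the ordinary convolution, then restore the original leading shape) is the behaviour that the testConvExpandedBatch tests check from Python. A rough sketch of the same equivalence through the public wrapper, assuming an eager context and a device with a registered `Conv` kernel; the values are arbitrary.

# Rough sketch of the collapse/convolve/restore equivalence.
import numpy as np
from tensorflow.python.ops import gen_nn_ops

x = np.arange(2 * 5 * 2 * 3 * 3, dtype=np.float32).reshape(2, 5, 2, 3, 3)
w = np.ones([1, 1, 3, 3], dtype=np.float32)

# batch_dims=2: the leading [2, 5] are both treated as batch dimensions.
y_expanded = gen_nn_ops.conv(
    x, w, strides=[1, 1, 1, 1], padding="VALID", batch_dims=2)

# Same result by hand: flatten the batch dims, convolve, compare reshaped.
y_flat = gen_nn_ops.conv(
    x.reshape(10, 2, 3, 3), w, strides=[1, 1, 1, 1], padding="VALID")
assert np.allclose(y_expanded.numpy().reshape(10, 2, 3, 3), y_flat.numpy())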
+ if (attrs_.batch_dims == 0) { + xla::Shape no_batch_shape(out_shape); + no_batch_shape.DeleteDimension(0); + out = xla::Reshape(out, no_batch_shape.dimensions()); + } else if (attrs_.batch_dims > 1) { + xla::Shape expanded_out_shape(input_shape); + for (int i = attrs_.batch_dims; i < input_shape.rank(); ++i) { + expanded_out_shape.set_dimensions( + i, out_shape.dimensions(i - (attrs_.batch_dims - 1))); + } + out = xla::Reshape(out, expanded_out_shape.dimensions()); + } + ctx->SetOutput(0, out); + } + + protected: + ConvNDOpAttrs attrs_; +}; +REGISTER_XLA_CONV_OP(Name("Conv"), ConvNDOp); + class Conv2DOp : public ConvOp { public: explicit Conv2DOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/api_def/base_api/api_def_Conv.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv.pbtxt index 40d682bbe3bf3b..9bdcd533eab446 100644 --- a/tensorflow/core/api_def/base_api/api_def_Conv.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Conv.pbtxt @@ -23,7 +23,7 @@ END out_arg { name: "output" description: < Date: Thu, 27 Jul 2023 16:55:12 +0100 Subject: [PATCH 232/410] [Linaro:ARM_CI] Retry flaky tests on AARCH64 as temp measure Until these flaky tests are resolved in x86 builds have them retry in AARCH64 builds. --- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh | 2 +- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh index cae5791083a1fa..f4d6d0f49bdb09 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh @@ -96,7 +96,7 @@ source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh # Export optional variables for running the tests export TF_BUILD_FLAGS="--config=mkl_aarch64_threadpool --copt=-flax-vector-conversions" -export TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ +export TF_TEST_FLAGS="${TF_BUILD_FLAGS} --flaky_test_attempts=3 \ --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=tf_api_version=2 \ --test_lang_filters=py --test_size_filters=small,medium \ --test_output=errors --verbose_failures=true --test_keep_going --notest_verbose_timeout_warnings" diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh index 817b30da6cff88..30c8b940a897c4 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh @@ -101,7 +101,7 @@ source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh # Export optional variables for running the tests export TF_BUILD_FLAGS="--config=mkl_aarch64_threadpool --copt=-flax-vector-conversions" -export TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ +export TF_TEST_FLAGS="${TF_BUILD_FLAGS} --flaky_test_attempts=3 \ --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=tf_api_version=2 \ --test_lang_filters=py --test_size_filters=small,medium \ --test_output=errors --verbose_failures=true --test_keep_going --notest_verbose_timeout_warnings" From 36747a8d1de7e8593c24d95bdbf33298886745c5 Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 27 Jul 2023 09:02:50 -0700 Subject: [PATCH 233/410] Bump gemm rewrite test tolerance --- .../compiler/xla/service/gpu/tests/gemm_rewrite_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc 
b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 64482553843808..34a00c6e10051a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -422,7 +422,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] { @@ -461,7 +461,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] { @@ -582,7 +582,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] { @@ -622,7 +622,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] { From 9da78ff4f4a09ad20932b7b213a6524d9471f061 Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Thu, 27 Jul 2023 09:06:37 -0700 Subject: [PATCH 234/410] Temporarily reenable flaky_test_attempts for upcoming TensorFlow 2.14 release and branch cut window. This flag will be removed from master after branch cut. PiperOrigin-RevId: 551548089 --- .bazelrc | 4 ++++ tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh | 2 +- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh | 2 +- tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh | 2 +- .../tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc | 3 ++- .../tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc | 2 +- 6 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index b860e90fc79360..21d45f9991ff44 100644 --- a/.bazelrc +++ b/.bazelrc @@ -415,6 +415,8 @@ build:rbe --google_default_credentials build:rbe --bes_backend=buildeventservice.googleapis.com build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" build:rbe --bes_timeout=600s +// TODO(b/290857564): Remove from mainline after 2.14 branch cut +build:rbe --flaky_test_attempts=3 build:rbe --define=EXECUTOR=remote build:rbe --jobs=800 build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com @@ -555,6 +557,8 @@ try-import %workspace%/.bazelrc.user # Here are bazelrc configs for release builds build:release_base --config=v2 test:release_base --test_size_filters=small,medium +// TODO(b/290857564): Remove from mainline after 2.14 branch cut +test:release_base --flaky_test_attempts=3 build:release_cpu_linux --config=release_base build:release_cpu_linux --config=avx_linux diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh index 34ba8a10fcfeab..122d8838c0c5e3 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_cpp.sh @@ -71,7 +71,7 @@ source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh export TF_BUILD_FLAGS="--config=mkl_aarch64_threadpool --copt=-flax-vector-conversions" export 
TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=tf_api_version=2 \ - --test_lang_filters=cc --test_size_filters=small,medium \ + --test_lang_filters=cc --flaky_test_attempts=3 --test_size_filters=small,medium \ --test_output=errors --verbose_failures=true --test_keep_going --notest_verbose_timeout_warnings" export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} ${ARM_SKIP_TESTS}" export TF_FILTER_TAGS="-no_oss,-oss_excluded,-oss_serial,-v1only,-benchmark-test,-no_aarch64,-gpu,-tpu,-no_oss_py39,-no_oss_py310" diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh index cae5791083a1fa..b40b95c275746f 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_nonpip.sh @@ -98,7 +98,7 @@ source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh export TF_BUILD_FLAGS="--config=mkl_aarch64_threadpool --copt=-flax-vector-conversions" export TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=tf_api_version=2 \ - --test_lang_filters=py --test_size_filters=small,medium \ + --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium \ --test_output=errors --verbose_failures=true --test_keep_going --notest_verbose_timeout_warnings" export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} ${ARM_SKIP_TESTS}" export TF_FILTER_TAGS="-no_oss,-oss_excluded,-oss_serial,-v1only,-benchmark-test,-no_aarch64,-gpu,-tpu,-no_oss_py39,-no_oss_py310" diff --git a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh index 817b30da6cff88..c5a625c29b58cc 100644 --- a/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh +++ b/tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh @@ -103,7 +103,7 @@ source tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS_EXTENDED.sh export TF_BUILD_FLAGS="--config=mkl_aarch64_threadpool --copt=-flax-vector-conversions" export TF_TEST_FLAGS="${TF_BUILD_FLAGS} \ --test_env=TF_ENABLE_ONEDNN_OPTS=1 --test_env=TF2_BEHAVIOR=1 --define=tf_api_version=2 \ - --test_lang_filters=py --test_size_filters=small,medium \ + --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium \ --test_output=errors --verbose_failures=true --test_keep_going --notest_verbose_timeout_warnings" export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} ${ARM_SKIP_TESTS}" export TF_FILTER_TAGS="-no_oss,-oss_excluded,-oss_serial,-v1only,-benchmark-test,-no_aarch64,-gpu,-tpu,-no_oss_py39,-no_oss_py310" diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc index 80fdc6bb2a1250..0fe43aaf627fea 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu.bazelrc @@ -53,9 +53,10 @@ test --test_summary=short # Pass --config=nonpip to run the same suite of tests. If you want to run just # one test for investigation, you don't need --config=nonpip; just run the # bazel test invocation as normal. 
+ test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc index 4b64c5fd2c5e5c..a51ae89424a6ca 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu.bazelrc @@ -78,7 +78,7 @@ test --test_summary=short # bazel test invocation as normal. test:nonpip_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py39,-no_oss_py310 test:nonpip_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-no_cuda11,-no_oss_py39,-no_oss_py310 -test:nonpip_filters --test_lang_filters=py --test_size_filters=small,medium +test:nonpip_filters --test_lang_filters=py --flaky_test_attempts=3 --test_size_filters=small,medium test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # For building libtensorflow archives From b0f659b50bc8347a1afcb5d044f9d09c0aebcc48 Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Thu, 27 Jul 2023 09:26:43 -0700 Subject: [PATCH 235/410] Fix typo in error message. PiperOrigin-RevId: 551553440 --- tensorflow/core/kernels/image/encode_png_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/image/encode_png_op.cc b/tensorflow/core/kernels/image/encode_png_op.cc index e92ea1d9260e6d..bde5d478846051 100644 --- a/tensorflow/core/kernels/image/encode_png_op.cc +++ b/tensorflow/core/kernels/image/encode_png_op.cc @@ -61,7 +61,7 @@ class EncodePngOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& image = context->input(0); OP_REQUIRES(context, image.dims() >= 3, - errors::InvalidArgument("images must be ast least rank 3", + errors::InvalidArgument("images must be at least rank 3", image.shape().DebugString())); OP_REQUIRES(context, image.NumElements() >= 0, errors::Internal("Invalid image provided.")); From 007324c0c15723c34806d63abb5d25f1d19b53e7 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 27 Jul 2023 09:28:07 -0700 Subject: [PATCH 236/410] Disable the failing test for ASAN for now. 
PiperOrigin-RevId: 551553829 --- tensorflow/python/ops/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 9343cc010648b3..4d4561dbf3473a 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -4481,6 +4481,7 @@ py_strict_library( py_strict_test( name = "weak_tensor_ops_test", srcs = ["weak_tensor_ops_test.py"], + tags = ["noasan"], # TODO(b/293304945): Reenable. deps = [ ":array_ops", ":array_ops_gen", From d044bd6974b500901c8d1067d937e35d6e8a148a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 09:32:14 -0700 Subject: [PATCH 237/410] Add EagerWeakTensor and GraphWeakTensor. EagerWeakTensor and GraphWeakTensor are wrapper classes that are introduced for WeakTensor to pass instance checks for core.Value or core.Symbol. EagerWeakTensor is created when WeakTensor is created with an eager Tensor and GraphWeakTensor is created when WeakTensor is created with a graph Tensor. PiperOrigin-RevId: 551554969 --- tensorflow/python/eager/execute.py | 7 + tensorflow/python/framework/BUILD | 7 +- tensorflow/python/framework/weak_tensor.py | 119 +++++++---- .../python/framework/weak_tensor_test.py | 199 ++++++++++-------- tensorflow/python/ops/weak_tensor_ops.py | 8 +- .../python/ops/weak_tensor_test_util.py | 4 +- 6 files changed, 211 insertions(+), 133 deletions(-) diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 94236fa66fdfcc..d524dd90aae649 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -50,6 +50,13 @@ def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None): # pylint: disable=protected-access try: ctx.ensure_initialized() + # Convert any objects of type core_types.Tensor to Tensor. 
+ inputs = [ + tensor_conversion_registry.convert(t) + if isinstance(t, core_types.Tensor) + else t + for t in inputs + ] tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, inputs, attrs, num_outputs) except core._NotOkStatusException as e: diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index b281baf9be9400..8b09a321682609 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1615,7 +1615,7 @@ py_strict_library( ], ) -py_strict_library( +pytype_strict_library( name = "weak_tensor", srcs = ["weak_tensor.py"], srcs_version = "PY3", @@ -1627,7 +1627,9 @@ py_strict_library( ":ops", ":tensor", ":tensor_conversion_registry", + ":tensor_spec", "//tensorflow/python/eager:context", + "//tensorflow/python/types:core", "//third_party/py/numpy", ], ) @@ -1642,16 +1644,19 @@ tf_py_strict_test( ":constant_op", ":dtypes", ":errors", + ":extension_type", ":ops", ":tensor", ":test_lib", ":weak_tensor", "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/module", "//tensorflow/python/platform:test", "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:save", + "//tensorflow/python/types:core", ], ) diff --git a/tensorflow/python/framework/weak_tensor.py b/tensorflow/python/framework/weak_tensor.py index 6b3f456028972d..e89a07d7f04284 100644 --- a/tensorflow/python/framework/weak_tensor.py +++ b/tensorflow/python/framework/weak_tensor.py @@ -14,7 +14,6 @@ # ============================================================================= """An extension type that represents WeakTensor.""" - from typing import Optional import numpy as np @@ -27,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor as tensor_lib from tensorflow.python.framework import tensor_conversion_registry +from tensorflow.python.types import core _ALLOWED_WEAK_DTYPES = ( @@ -48,9 +48,7 @@ def replace_gradient_components(self, weak_tensor, component_grads): return weak_tensor._type_spec._from_components([component_grads]) # pylint: disable=protected-access -# TODO(b/285024542): Modify the isinstance() checks to include WeakTensor. -# instance. -class WeakTensor(extension_type.ExtensionType): +class WeakTensor(extension_type.BatchableExtensionType, core.Tensor): """A weakly typed Tensor. A simple wrapper class that contains a normal Tensor. @@ -91,13 +89,6 @@ def __getattr__(self, *args, **kwargs): # directly expose Tensor's methods. return getattr(self.tensor, *args, **kwargs) - def __array__(self, dtype=None): - # We need to explicitly call np.array() because - # self_tensor.__array__() for scalars raise: - # ValueError: object __array__ method not producing an array - # resource_variable_ops also follows the same pattern. - return np.array(self.tensor.__array__(dtype)) - def _disallow(self, task): raise errors.OperatorNotAllowedInGraphError( f"{task} is not allowed. 
You can attempt the following resolutions to" @@ -143,21 +134,6 @@ def __tf_tensor__( ): return self.tensor.__tf_tensor__(dtype=dtype, name=name) - def __format__(self, format_spec): - return f"{self.tensor.__format__(format_spec)} weakly typed" - - def __complex__(self): - return self.tensor.__complex__() - - def __int__(self): - return self.tensor.__int__() - - def __float__(self): - return self.tensor.__float__() - - def __index__(self): - return self.tensor.__index__() - def __deepcopy__(self, memo): # Eager Tensors are immutable so it's safe to return themselves as a copy. del memo @@ -167,20 +143,33 @@ def to_tensor(self): """Converts this 'WeakTensor' into a 'tf.Tensor'.""" return self.tensor - def numpy(self): - """Copy of the contents of this WeakTensor into a NumPy array or scalar.""" - if not isinstance(self.tensor, ops.EagerTensor): - raise ValueError("WeakTensor.numpy() is only supported in eager mode.") - return self.tensor.numpy() - def _as_graph_element(self): """Convert `self` to a graph element.""" return self.tensor @classmethod def from_tensor(cls, tensor): - """Converts a 'tf.Tensor' into a 'WeakTensor'.""" - return WeakTensor(tensor) + """Converts a 'tf.Tensor' into a 'WeakTensor'. + + This should be the standard way of creating a WeakTensor instead + of directly calling the WeakTensor constructor. + + Args: + tensor: The `tf.Tensor` that should be converted into a 'WeakTensor'. + + Returns: + A `EagerWeakTensor` or 'GraphWeakTensor' that holds the `tensor`. + """ + if isinstance(tensor, core.Value): + return EagerWeakTensor(tensor) + if isinstance(tensor, core.Symbol): + return GraphWeakTensor(tensor) + raise errors.InvalidArgumentError( + None, + None, + "WeakTensor can only be constructed from tf.Tensor or tf.WeakTensor," + f" but {type(tensor)} was given.", + ) # Redefine `shape` and `dtype` rather than relying on `getattr` because the # class derives from core.Tensor which returns None in the two methods. @@ -199,6 +188,50 @@ def is_tensor_like(self): __composite_gradient__ = WeakTensorGradient() +# EagerWeakTensor and GraphWeakTensor are wrapper classes that are +# introduced for WeakTensor to pass instance checks for core.Value or +# core.Symbol. +class EagerWeakTensor(core.Value, WeakTensor): + """A weakly typed Eager Tensor.""" + + __name__ = "tf.EagerWeakTensor" + + # Methods that are only avilable for EagerTensor. + def numpy(self): + """Copy of the contents of this EagerWeakTensor into a NumPy array or scalar.""" + if not isinstance(self.tensor, ops.EagerTensor): + raise ValueError("WeakTensor.numpy() is only supported in eager mode.") + return self.tensor.numpy() + + def __complex__(self): + return self.tensor.__complex__() + + def __int__(self): + return self.tensor.__int__() + + def __float__(self): + return self.tensor.__float__() + + def __index__(self): + return self.tensor.__index__() + + def __format__(self, format_spec): + return f"{self.tensor.__format__(format_spec)} weakly typed" + + def __array__(self, dtype=None): + # We need to explicitly call np.array() because + # self_tensor.__array__() for scalars raise: + # ValueError: object __array__ method not producing an array + # resource_variable_ops also follows the same pattern. + return np.array(self.tensor.__array__(dtype)) + + +class GraphWeakTensor(core.Symbol, WeakTensor): + """A weakly typed Graph Tensor.""" + + __name__ = "tf.GraphWeakTensor" + + class _WeakTensorIterator(object): """Iterates over the leading dim of a WeakTensor. 
Performs no error checks.""" @@ -215,20 +248,26 @@ def __iter__(self): def __next__(self): if self._index == self._limit: raise StopIteration - result = WeakTensor(self._weak_tensor.tensor[self._index]) + result = WeakTensor.from_tensor((self._weak_tensor.tensor[self._index])) self._index += 1 return result -def maybe_convert_to_weak_tensor(t, is_weak): - return WeakTensor(t) if is_weak else t +def convert_to_weak_tensor_or_tensor(t, to_weak): + if to_weak: + return WeakTensor.from_tensor(t) + # We should return a normal Tensor because is_weak = False. + if isinstance(t, WeakTensor): + return t.tensor + return t -# convert_to_tensor(WeakTensor) should return a WeakTensor because WeakTensor is -# a 'Tensor' with a special dtype. +# convert_to_tensor(WeakTensor) should return a Tensor because convert_to_tensor +# is mostly used internally and we want to limit the scope of WeakTensor +# creation to tf.constant and WeakTensor patched ops. def weak_tensor_conversion_function(t): if isinstance(t, WeakTensor): - return t + return t.tensor tensor_conversion_registry.register_tensor_conversion_function( diff --git a/tensorflow/python/framework/weak_tensor_test.py b/tensorflow/python/framework/weak_tensor_test.py index 8fec36d82e3fe4..152e7df414ca5b 100644 --- a/tensorflow/python/framework/weak_tensor_test.py +++ b/tensorflow/python/framework/weak_tensor_test.py @@ -17,132 +17,106 @@ import numpy as np from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import extension_type from tensorflow.python.framework import ops from tensorflow.python.framework import tensor from tensorflow.python.framework import test_util -from tensorflow.python.framework import weak_tensor +from tensorflow.python.framework.weak_tensor import EagerWeakTensor +from tensorflow.python.framework.weak_tensor import GraphWeakTensor +from tensorflow.python.framework.weak_tensor import WeakTensor from tensorflow.python.module import module from tensorflow.python.platform import googletest from tensorflow.python.saved_model.load import load from tensorflow.python.saved_model.save import save +from tensorflow.python.types import core class WeakTensorTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_basic(self): - a = weak_tensor.WeakTensor(constant_op.constant(1, dtypes.int32)) + a = WeakTensor.from_tensor(constant_op.constant(1, dtypes.int32)) self.assertEqual(a.dtype, dtypes.int32) self.assertEqual(a.shape, []) - self.assertEqual(a.numpy(), 1) - self.assertEqual(np.array(a), 1) - with self.assertRaises(TypeError): - _ = weak_tensor.WeakTensor(constant_op.constant(1, dtypes.int16)) b = [1.0, 2.0], [3.0, 4.0] - bwt = weak_tensor.WeakTensor(constant_op.constant(b, dtypes.float32)) - self.assertEqual(bwt.dtype, dtypes.float32) - self.assertEqual(bwt.shape, [2, 2]) - self.assertAllEqual(bwt.numpy(), np.array(b, dtype=np.float32)) - self.assertAllEqual(np.array(bwt), np.array(b, dtype=np.float32)) + b_wt = WeakTensor.from_tensor(constant_op.constant(b, dtypes.float32)) + self.assertEqual(b_wt.dtype, dtypes.float32) + self.assertEqual(b_wt.shape, [2, 2]) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_init(self): # Make sure an exception is thrown for unallowed dtypes. 
t = constant_op.constant(1, dtypes.int16) with self.assertRaises(TypeError): - _ = weak_tensor.WeakTensor(t) - - def test_weak_tensor_num_methods(self): - t = constant_op.constant(1, dtypes.int32) - wt = weak_tensor.WeakTensor(t) - - self.assertEqual(complex(wt), complex(1)) - self.assertEqual(int(wt), int(1)) - self.assertEqual(float(wt), float(1)) - self.assertEqual(wt.__index__(), int(1)) + _ = WeakTensor.from_tensor(t) + + @test_util.run_in_graph_and_eager_modes + def test_weak_tensor_inheritance(self): + a = WeakTensor.from_tensor(constant_op.constant([1, 2, 3], dtypes.int32)) + self.assertIsInstance(a, WeakTensor) + self.assertIsInstance(a, core.Tensor) + self.assertIsInstance(a, extension_type.ExtensionType) + if context.executing_eagerly(): + self.assertIsInstance(a, core.Value) + self.assertIsInstance(a, EagerWeakTensor) + else: + self.assertIsInstance(a, core.Symbol) + self.assertIsInstance(a, GraphWeakTensor) + + def test_weak_tensor_eager_methods(self): + wt = WeakTensor.from_tensor(constant_op.constant(2, dtypes.int32)) + b = [1.0, 2.0], [3.0, 4.0] + b_wt = WeakTensor.from_tensor(constant_op.constant(b, dtypes.float32)) - def test_weak_tensor_format(self): - t = constant_op.constant(2, dtypes.int32) - wt = weak_tensor.WeakTensor(t) - # Format to binary representation. + self.assertEqual(complex(wt), complex(2)) + self.assertEqual(int(wt), int(2)) + self.assertEqual(float(wt), float(2)) + self.assertEqual(wt.__index__(), int(2)) + self.assertEqual(wt.numpy(), 2) self.assertEqual(format(wt, 'b'), '10 weakly typed') + self.assertEqual(np.array(wt), 2) + self.assertAllEqual(np.array(b_wt), np.array(b, dtype=np.float32)) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_bool(self): # Test to make sure WeakTensor(bool) isn't used as a bool. with self.assertRaises(TypeError): - if weak_tensor.WeakTensor(constant_op.constant(True)): + if WeakTensor.from_tensor(constant_op.constant(True)): raise TypeError('Type error is raised because WeakTensor != bool') - def test_weak_tensor_iter(self): - # Test normal weakTensor iteration. - t = constant_op.constant([0, 1, 2], dtypes.int32) - wt = weak_tensor.WeakTensor(t) - it_weak_tensor = iter(wt) - for i in range(len(wt)): - self.assertEqual( - next(it_weak_tensor), weak_tensor.WeakTensor(constant_op.constant(i)) - ) - - # Test multi-dimensional weakTensor iteration. - t_multi = constant_op.constant([[1, 2], [3, 4]], dtypes.int32) - wt_multi = weak_tensor.WeakTensor(t_multi) - it_wt_multi_tensor = iter(wt_multi) - self.assertEqual( - next(it_wt_multi_tensor), weak_tensor.WeakTensor(t_multi[0]) - ) - self.assertEqual( - next(it_wt_multi_tensor), weak_tensor.WeakTensor(t_multi[1]) - ) - - # Test scalar weakTensor iteration. - t_scalar = constant_op.constant(1, dtypes.int32) - wt_scalar = weak_tensor.WeakTensor(t_scalar) - with self.assertRaises(TypeError): - # Cannot iterate over a scalar tensor. - _ = iter(wt_scalar) - - # Make sure iteration is not allowed in Graph mode. - ops.disable_eager_execution() - with self.assertRaisesRegex( - errors.OperatorNotAllowedInGraphError, - 'Iterating over a symbolic `tf.WeakTensor` is not allowed. You can' - ' attempt the following resolutions to the problem: If you are running' - ' in Graph mode, use Eager execution mode or decorate this function' - ' with @tf.function. If you are using AutoGraph, you can try decorating' - ' this function with @tf.function. If that does not work, then you may' - ' be using an unsupported feature or your source code may not be' - ' visible to AutoGraph. 
See' - ' https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code' - ' for more information.', - ): - _ = iter(wt) - ops.enable_eager_execution() - + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_getattr(self): - t = constant_op.constant(1, dtypes.int32) - wt = weak_tensor.WeakTensor(t) + wt = WeakTensor.from_tensor(constant_op.constant(1, dtypes.int32)) wt_name = getattr(wt, '__name__', None) - self.assertEqual(wt_name, 'tf.WeakTensor') + if context.executing_eagerly(): + self.assertEqual(wt_name, 'tf.EagerWeakTensor') + else: + self.assertEqual(wt_name, 'tf.GraphWeakTensor') + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_in_tf_func(self): @def_function.function() def f(x): return x t = constant_op.constant(1, dtypes.int32) - wt = weak_tensor.WeakTensor(t) + wt = WeakTensor.from_tensor(t) res = f(wt) - self.assertIsInstance(res, weak_tensor.WeakTensor) + self.assertIsInstance(res, WeakTensor) _ = f(t) self.assertEqual(f.experimental_get_tracing_count(), 2) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_in_tf_func_with_branch_error(self): a = constant_op.constant(1, dtypes.int32) - b = weak_tensor.WeakTensor(constant_op.constant(1, dtypes.int32)) + b = WeakTensor.from_tensor(a) @def_function.function() def f(c, a, b): @@ -155,10 +129,11 @@ def f(c, a, b): # if and else branch cannot return two different types in a tf.function. _ = f(constant_op.constant(2, dtypes.int32), a, b) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_in_tf_func_with_spec(self): # Test weak tensor spec with matching input. - weak_tensor_spec = weak_tensor.WeakTensor.Spec(tensor.TensorSpec([2])) - wt = weak_tensor.WeakTensor(constant_op.constant([1.0, 2.0])) + weak_tensor_spec = WeakTensor.Spec(tensor.TensorSpec([2])) + wt = WeakTensor.from_tensor(constant_op.constant([1.0, 2.0])) @def_function.function(input_signature=[weak_tensor_spec]) def f(x): @@ -166,20 +141,21 @@ def f(x): _ = f(wt) # Test weak tensor spec with mismatching input. 
- wt_mismatch = weak_tensor.WeakTensor(constant_op.constant([1.0, 2.0, 3.0])) + wt_mismatch = WeakTensor.from_tensor(constant_op.constant([1.0, 2.0, 3.0])) with self.assertRaises(TypeError): _ = f(wt_mismatch) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_gradient(self): - x = weak_tensor.WeakTensor(constant_op.constant([3.0, 4.0, 5.0])) + x = WeakTensor.from_tensor(constant_op.constant([3.0, 4.0, 5.0])) with backprop.GradientTape() as g: g.watch(x) y = x dy_dx = g.gradient(y, x) - self.assertEqual( - dy_dx, weak_tensor.WeakTensor(constant_op.constant([1.0, 1.0, 1.0])) - ) + self.assertAllEqual(dy_dx, [1.0, 1.0, 1.0]) + self.assertIsInstance(dy_dx, WeakTensor) + @test_util.run_in_graph_and_eager_modes def test_weak_tensor_in_restored_function(self): class CustomModule(module.Module): @@ -190,13 +166,13 @@ def __call__(self, x): return x m = CustomModule() - a = weak_tensor.WeakTensor(constant_op.constant(1, dtypes.int32)) + a = WeakTensor.from_tensor(constant_op.constant(1, dtypes.int32)) _ = m(a) save(m, '/tmp/f') m_loaded = load('/tmp/f') res = m_loaded(a) - self.assertIsInstance(res, weak_tensor.WeakTensor) + self.assertIsInstance(res, WeakTensor) b = constant_op.constant(1, dtypes.int32) with self.assertRaisesRegex( @@ -207,7 +183,7 @@ def __call__(self, x): def test_weak_tensor_format_to_string(self): # __str__ test in eager mode t = constant_op.constant([1.0, 2.0], dtypes.float32) - wt = weak_tensor.WeakTensor(t) + wt = WeakTensor(t) wt_str = 'tf.Tensor([1. 2.], shape=(2,), dtype=float32, weak=True)' self.assertEqual(str(wt), wt_str) @@ -222,7 +198,7 @@ def test_weak_tensor_format_to_string(self): def f(): # __str__ test in graph mode t = constant_op.constant([1.0, 2.0], dtypes.float32) - wt = weak_tensor.WeakTensor(t) + wt = WeakTensor(t) wt_str = 'Tensor("Const:0", shape=(2,), dtype=float32, weak=True)' self.assertEqual(str(wt), wt_str) @@ -233,6 +209,53 @@ def f(): _ = f() + def test_weak_tensor_iter(self): + # Test normal weakTensor iteration. + t = constant_op.constant([0, 1, 2], dtypes.int32) + wt = WeakTensor.from_tensor(t) + it_weak_tensor = iter(wt) + for i in range(len(wt)): + self.assertAllEqual( + next(it_weak_tensor), + WeakTensor.from_tensor(constant_op.constant(i)), + ) + + # Test multi-dimensional weakTensor iteration. + t_multi = constant_op.constant([[1, 2], [3, 4]], dtypes.int32) + wt_multi = WeakTensor(t_multi) + it_wt_multi_tensor = iter(wt_multi) + self.assertAllEqual( + next(it_wt_multi_tensor), WeakTensor.from_tensor(t_multi[0]) + ) + self.assertAllEqual( + next(it_wt_multi_tensor), WeakTensor.from_tensor(t_multi[1]) + ) + + # Test scalar weakTensor iteration. + t_scalar = constant_op.constant(1, dtypes.int32) + wt_scalar = WeakTensor.from_tensor(t_scalar) + with self.assertRaises(TypeError): + # Cannot iterate over a scalar tensor. + _ = iter(wt_scalar) + + @test_util.deprecated_graph_mode_only + def test_weak_tensor_iter_graph_mode(self): + # Make sure iteration is not allowed in Graph mode. + wt = WeakTensor.from_tensor(constant_op.constant([0, 1, 2], dtypes.int32)) + with self.assertRaisesRegex( + errors.OperatorNotAllowedInGraphError, + 'Iterating over a symbolic `tf.WeakTensor` is not allowed. You can' + ' attempt the following resolutions to the problem: If you are running' + ' in Graph mode, use Eager execution mode or decorate this function' + ' with @tf.function. If you are using AutoGraph, you can try decorating' + ' this function with @tf.function. 
If that does not work, then you may' + ' be using an unsupported feature or your source code may not be' + ' visible to AutoGraph. See' + ' https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code' + ' for more information.', + ): + _ = iter(wt) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/ops/weak_tensor_ops.py b/tensorflow/python/ops/weak_tensor_ops.py index 75535ab505f8ce..44662802413c10 100644 --- a/tensorflow/python/ops/weak_tensor_ops.py +++ b/tensorflow/python/ops/weak_tensor_ops.py @@ -97,7 +97,9 @@ def wrapper(*args, **kwargs): # Only return WeakTensor when dtype is NOT specified. if bound_kwargs.get("dtype", None) is not None: is_weak = False - return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + return weak_tensor.convert_to_weak_tensor_or_tensor( + op(**bound_kwargs), is_weak + ) wrapper = tf_decorator.make_decorator(op, wrapper) @@ -144,7 +146,9 @@ def wrapper(*args, **kwargs): bound_kwargs[x_arg_name] = _convert_or_cast(x, target_type, "x") bound_kwargs[y_arg_name] = _convert_or_cast(y, target_type, "y") - return weak_tensor.maybe_convert_to_weak_tensor(op(**bound_kwargs), is_weak) + return weak_tensor.convert_to_weak_tensor_or_tensor( + op(**bound_kwargs), is_weak + ) wrapper = tf_decorator.make_decorator(op, wrapper) diff --git a/tensorflow/python/ops/weak_tensor_test_util.py b/tensorflow/python/ops/weak_tensor_test_util.py index 2eb399cbda0c47..71b2097795fbc8 100644 --- a/tensorflow/python/ops/weak_tensor_test_util.py +++ b/tensorflow/python/ops/weak_tensor_test_util.py @@ -24,7 +24,7 @@ def convert_to_input_type(base_input, input_type, dtype=None): if input_type == "WeakTensor": - return WeakTensor(constant_op.constant(base_input, dtype=dtype)) + return WeakTensor.from_tensor(constant_op.constant(base_input, dtype=dtype)) elif input_type == "Tensor": return constant_op.constant(base_input, dtype=dtype) elif input_type == "NumPy": @@ -37,7 +37,7 @@ def convert_to_input_type(base_input, input_type, dtype=None): def get_weak_tensor(*args, **kwargs): - return WeakTensor(constant_op.constant(*args, **kwargs)) + return WeakTensor.from_tensor(constant_op.constant(*args, **kwargs)) class DtypeConversionTestEnv: From d640178dfae5dd2f8595702a239ebeacf82cf024 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:23:09 +0530 Subject: [PATCH 238/410] Update image_searcher.md --- .../inference_with_metadata/task_library/image_searcher.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md index 95a0294e66f4b5..fa1a901f17fcc8 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_searcher.md @@ -71,9 +71,9 @@ dependencies { // Other dependencies // Import the Task Vision Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.4' // Import the GPU delegate plugin Library for GPU inference - implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4' } ``` From 79f5dee582e3d3ff9f1e2ca09344fb7ab3ffe274 Mon Sep 
17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:24:34 +0530 Subject: [PATCH 239/410] Update nl_classifier.md --- .../inference_with_metadata/task_library/nl_classifier.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md index 8ce43494bd9912..2c4b2459aca5d1 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md @@ -54,9 +54,9 @@ dependencies { // Other dependencies // Import the Task Vision Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0' + implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4' // Import the GPU delegate plugin Library for GPU inference - implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.3.0' + implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4' } ``` @@ -94,7 +94,7 @@ Add the TensorFlowLiteTaskText pod in Podfile ``` target 'MySwiftAppWithTaskAPI' do use_frameworks! - pod 'TensorFlowLiteTaskText', '~> 0.4.3' + pod 'TensorFlowLiteTaskText', '~> 0.4.4' end ``` From 7b9c7e6ba2ffb365ee6498132ef8c4b608033ce3 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:25:43 +0530 Subject: [PATCH 240/410] Update text_searcher.md --- .../inference_with_metadata/task_library/text_searcher.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md index 7179d989a5a9db..000d2b9ed28ddf 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/text_searcher.md @@ -79,9 +79,9 @@ dependencies { // Other dependencies // Import the Task Vision Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-task-vision:0.4.4' // Import the GPU delegate plugin Library for GPU inference - implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4' } ``` From a53ed37fd464b57e1f3c8ed82e7da6c28de550d2 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:27:07 +0530 Subject: [PATCH 241/410] Update bert_nl_classifier.md --- .../task_library/bert_nl_classifier.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md index 3e8a62413c2a12..f156880316ebb5 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md @@ -49,7 +49,7 @@ dependencies { // Other dependencies // Import the Task Text Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0' + implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4' } ``` @@ -85,7 +85,7 @@ Add the TensorFlowLiteTaskText 
pod in Podfile ``` target 'MySwiftAppWithTaskAPI' do use_frameworks! - pod 'TensorFlowLiteTaskText', '~> 0.2.0' + pod 'TensorFlowLiteTaskText', '~> 0.4.4' end ``` From 40881181e736d1cca564b383efa3ae720a81ba9d Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:28:03 +0530 Subject: [PATCH 242/410] Update bert_question_answerer.md --- .../task_library/bert_question_answerer.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md index 1e40d1d03e55b7..2754eb1bd0ba9a 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md @@ -49,7 +49,7 @@ dependencies { // Other dependencies // Import the Task Text Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0' + implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4' } ``` @@ -86,7 +86,7 @@ Add the TensorFlowLiteTaskText pod in Podfile ``` target 'MySwiftAppWithTaskAPI' do use_frameworks! - pod 'TensorFlowLiteTaskText', '~> 0.2.0' + pod 'TensorFlowLiteTaskText', '~> 0.4.4' end ``` From c2aba639c0df2c825487d1479b0687969799c2b8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 09:54:51 -0700 Subject: [PATCH 243/410] Move legalize-tf-types pass to MOT directory This pass is currently used to lower qint quantized types. Move this to quantization/ directory to consolidate with other passes for lowering UniformQuantized ops. Also rename pass to convert-tf-quant-types to better reflect its intent. 
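For reference, a minimal sketch of the rewrite the renamed pass performs, in the
style of the convert-tf-quant-types.mlir test below (the tf.Relu body here is an
assumed illustration; the qint8 -> i8 signature rewrite is the point):

  // RUN: stablehlo-quant-opt -convert-tf-quant-types %s | FileCheck %s

  // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8>
  func.func @relu_qint8(%arg0: tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint8> {
    // CHECK: "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
    %0 = "tf.Relu"(%arg0) : (tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint8>
    func.return %0 : tensor<1x!tf_type.qint8>
  }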
PiperOrigin-RevId: 551561491 --- tensorflow/compiler/mlir/BUILD | 1 + .../mlir/lite/stablehlo/odml_to_stablehlo.cc | 1 - tensorflow/compiler/mlir/python/BUILD | 1 + tensorflow/compiler/mlir/python/mlir.cc | 3 +- .../mlir/quantization/stablehlo/BUILD | 26 ++++++++- .../passes/bridge/convert_tf_quant_types.cc} | 56 ++++++++----------- .../bridge/convert_tf_quant_types_test.cc} | 23 +++++--- .../stablehlo/passes/bridge/passes.h | 6 ++ .../stablehlo/passes/bridge/passes.td | 13 +++++ .../tests/bridge/convert-tf-quant-types.mlir} | 2 +- .../mlir/tf2xla/api/v0/compile_mlir_util.cc | 3 +- .../compiler/mlir/tf2xla/transforms/BUILD | 22 -------- .../compiler/mlir/tf2xla/transforms/passes.h | 4 -- .../transforms/xla_legalize_tf_passes.td | 18 ------ tensorflow/compiler/mlir/tf_mlir_opt_main.cc | 3 +- 15 files changed, 91 insertions(+), 91 deletions(-) rename tensorflow/compiler/mlir/{tf2xla/transforms/legalize_tf_types.cc => quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc} (82%) rename tensorflow/compiler/mlir/{tf2xla/transforms/legalize_tf_types_test.cc => quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc} (83%) rename tensorflow/compiler/mlir/{tf2xla/tests/legalize-tf-types.mlir => quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir} (97%) diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index dc58c4ddbc18a9..fc91afaec875b2 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -50,6 +50,7 @@ cc_library( ":register_common_dialects", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", # buildcleaner:keep + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:bridge_pass_test_pipeline_registration", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index 2df891a82022b4..ac41944f862a3e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -381,7 +381,6 @@ void initAllPasses() { // These are in compiler/mlir/tf2xla and not part of the above MHLO passes. mlir::mhlo::registerTfXlaPasses(); mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTfTypesPassPass(); mlir::xla_framework::registerXlaFrameworkPasses(); tensorflow::RegisterConvertMlirToXlaHloPipelineWithDefaults(); tensorflow::RegisterGraphOptimizationPasses(); diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index de42208c0aa9fc..54ebc8481ba178 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -15,6 +15,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncExtensions", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 85d8d4e087904a..40de3d6ef789d3 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -49,6 +49,7 @@ limitations under the License. 
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -96,7 +97,7 @@ static void RegisterPasses() { // passes. mlir::mhlo::registerTfXlaPasses(); mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTfTypesPassPass(); + mlir::stablehlo::registerBridgePasses(); mlir::tosa::registerLegalizeTosaPasses(); mlir::tosa::registerTFtoTOSALegalizationPipeline(); mlir::tosa::registerTFLtoTOSALegalizationPipeline(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index aa796e88734e4b..ccee2dc1f1408e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -113,14 +113,14 @@ cc_library( srcs = [ "passes/bridge/convert_mhlo_quant_to_int.cc", "passes/bridge/convert_tf_quant_ops_to_mhlo.cc", + "passes/bridge/convert_tf_quant_types.cc", ], hdrs = [ "passes/bridge/passes.h", ], compatible_with = get_compatible_with_portable(), visibility = [ - "//tensorflow/compiler/mlir/lite:__subpackages__", - "//tensorflow/compiler/mlir/tf2xla:__subpackages__", + "//tensorflow/compiler/mlir:__subpackages__", ], deps = [ ":bridge_passes_inc_gen", @@ -135,6 +135,7 @@ cc_library( "//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/util/quantization:uniform_quant_ops_attr_proto_cc", @@ -182,6 +183,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "convert_tf_quant_types_test", + srcs = ["passes/bridge/convert_tf_quant_types_test.cc"], + deps = [ + ":bridge_passes", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", + "//tensorflow/core:test", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "quantize_passes", srcs = [ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc similarity index 82% rename from tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc index 7879389fa24239..ef79092b8f7deb 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ limitations under the License. // some generic types that are legal in MHLO. This pass legalizes TF types into // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8. // Rewrites here should run before TF to MHLO op legalizations are run. -// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass -// rather than its own pass. #include #include @@ -39,12 +37,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/lib/monitoring/counter.h" -#define DEBUG_TYPE "xla-legalize-tf-types" - namespace mlir { -namespace mhlo { +namespace stablehlo { namespace { +#define GEN_PASS_DEF_CONVERTTFQUANTTYPES +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc" + // TODO: b/290366702 - Temporarily added metrics for debugging. auto *mlir_tf_quant_op_count = tensorflow::monitoring::Counter<1>::New( "/tensorflow/core/tf2xla/tf_quant_op_count" /*metric_name*/, @@ -101,9 +100,6 @@ bool IsUnSupportedOp(Operation *op) { >(op); } -// TODO(b/180234863): What's below this line is generic so convert it to a -// utility. - bool IsIllegalType(Type type) { return IsIllegalElementType(getElementTypeOrSelf(type)); } @@ -117,22 +113,20 @@ Type ToLegalType(Type type) { return type; } -class TfTypeConverter : public TypeConverter { +class TFQuantTypeConverter : public TypeConverter { public: - TfTypeConverter() { + TFQuantTypeConverter() { addConversion([](Type type) -> Type { return IsIllegalType(type) ? ToLegalType(type) : type; }); } }; -// An Op is illegal iff it contains an illegalType. -// TODO: b/289560952 - Move quantization related passes to MOT directories. Also -// reconsider the correct way to handle conversions of quantized types without -// quantization ops. -class TfTypeConversionTarget : public ConversionTarget { +// An Op is illegal iff it is non-UQ op and it contains qint types. +class TFQuantTypeConversionTarget : public ConversionTarget { public: - explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter) + explicit TFQuantTypeConversionTarget(MLIRContext &ctx, + TFQuantTypeConverter &converter) : ConversionTarget(ctx), converter_(converter) { markUnknownOpDynamicallyLegal([this](Operation *op) { // Do not convert UnifromQuantized ops. @@ -149,12 +143,12 @@ class TfTypeConversionTarget : public ConversionTarget { } private: - TfTypeConverter &converter_; + TFQuantTypeConverter &converter_; }; -class TfTypePattern : public ConversionPattern { +class TFQuantTypePattern : public ConversionPattern { public: - TfTypePattern(MLIRContext *ctx, TypeConverter &converter) + TFQuantTypePattern(MLIRContext *ctx, TypeConverter &converter) : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {} // The dialect conversion framework will call this matchAndRewrite on each @@ -189,30 +183,28 @@ class TfTypePattern : public ConversionPattern { } }; -#define GEN_PASS_DEF_LEGALIZETFTYPESPASS -#include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.h.inc" - -struct LegalizeTfTypesPass - : public impl::LegalizeTfTypesPassBase { +struct ConvertTFQuantTypes + : public impl::ConvertTFQuantTypesBase { void runOnOperation() override; }; -void LegalizeTfTypesPass::runOnOperation() { - TfTypeConverter converter; +// TODO: b/289560952 - add qint <-> int casts around TF UQ ops. 
+void ConvertTFQuantTypes::runOnOperation() { + TFQuantTypeConverter converter; RewritePatternSet patterns(&getContext()); - patterns.add(&getContext(), converter); + patterns.add(&getContext(), converter); populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); - TfTypeConversionTarget target(getContext(), converter); + TFQuantTypeConversionTarget target(getContext(), converter); if (failed(applyFullConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); } } // namespace -std::unique_ptr> CreateLegalizeTfTypesPass() { - return std::make_unique(); +std::unique_ptr> CreateConvertTFQuantTypesPass() { + return std::make_unique(); } -} // namespace mhlo +} // namespace stablehlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc similarity index 83% rename from tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types_test.cc rename to tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc index e35c00001ffabd..1f21ab98c8746a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc @@ -19,22 +19,25 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/tsl/platform/statusor.h" -namespace tensorflow { +namespace mlir { +namespace stablehlo { namespace { +using ::mlir::DialectRegistry; using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; -using ::mlir::mhlo::test::GetMlirModuleFromString; using ::tensorflow::monitoring::testing::CellReader; static constexpr char kMetricsName[] = @@ -43,12 +46,15 @@ static constexpr char kMetricsName[] = class LegalizeTfTypesTest : public ::testing::Test { protected: void CreateModule(const char* module_string) { - TF_ASSERT_OK_AND_ASSIGN(module_, - GetMlirModuleFromString(module_string, &context_)); + DialectRegistry mlir_registry; + RegisterCommonToolingDialects(mlir_registry); + context_.appendDialectRegistry(mlir_registry); + TF_ASSERT_OK( + tensorflow::DeserializeMlirModule(module_string, &context_, &module_)); pm_ = std::make_unique(&context_); pm_->addNestedPass( - mlir::mhlo::CreateLegalizeTfTypesPass()); + mlir::stablehlo::CreateConvertTFQuantTypesPass()); } mlir::LogicalResult Run() { return pm_->run(module_.get()); } @@ -99,4 +105,5 @@ TEST_F(LegalizeTfTypesTest, RecordsStreamzNoQuantOps) { } } // namespace -} // namespace tensorflow +} // namespace stablehlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h 
b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h index 9318f13ddce049..605e290e3316be 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h @@ -40,9 +40,15 @@ CreateConvertTFQuantOpsToMHLOPass(); void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context, RewritePatternSet *patterns); +// Creates an instance of the ConvertTFQuantTypes pass, which will convert TF +// qint types to int types and surround TF UniformQuantized ops with qint <-> +// int casts. +std::unique_ptr> CreateConvertTFQuantTypesPass(); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_CONVERTMHLOQUANTTOINT #define GEN_PASS_DECL_CONVERTTFQUANTOPSTOMHLO +#define GEN_PASS_DECL_CONVERTTFQUANTTYPES #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc" } // namespace stablehlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td index 08f1e32796905b..08a2987c03d764 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td @@ -48,3 +48,16 @@ def ConvertTFQuantOpsToMHLO : Pass<"quant-convert-tf-quant-ops-to-mhlo", "mlir:: "mhlo::MhloDialect", "tf_type::TFTypeDialect", "quant::QuantizationDialect"]; } + +def ConvertTFQuantTypes : Pass<"convert-tf-quant-types", "mlir::func::FuncOp"> { + let summary = "Replace TensorFlow qint types with int types."; + + let description = [{ + Converts TF ops with qint types to int types. Some UniformQuantized ops + argument/result allow qint type only. For such cases, add qint <-> int + tf.Cast around the ops so that they are still valid. 
+ }]; + + let constructor = "::mlir::stablehlo::CreateConvertTFQuantTypesPass()"; + let dependentDialects = ["TF::TensorFlowDialect", "tf_type::TFTypeDialect"]; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-types.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir similarity index 97% rename from tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-types.mlir rename to tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir index 2df83ec95b2766..c1fdf2366c7443 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-types.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-tf-quant-types.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-legalize-tf-types %s | FileCheck %s +// RUN: stablehlo-quant-opt -convert-tf-quant-types %s | FileCheck %s func.func @relu_qint8(%arg0: tensor<1x!tf_type.qint8>) -> tensor<1x!tf_type.qint8> { // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc index faf17af46284eb..4458e16c415c26 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v0/compile_mlir_util.cc @@ -431,7 +431,8 @@ void CreateConvertMlirToXlaHloPipeline( pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(mlir::TF::CreateLowerQuantizedPass()); - pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass()); + pm.addNestedPass( + mlir::stablehlo::CreateConvertTFQuantTypesPass()); for (auto& target_pass : custom_legalization_passes) { pm.addNestedPass(std::move(target_pass)); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index f3db3b60632c69..0057624e2a7c4f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -241,7 +241,6 @@ cc_library( "infeed_ops_xla_adjust_layout.cc", "legalize_tf_collective.cc", "legalize_tf_communication.cc", - "legalize_tf_types.cc", "tf_xla_passes.h.inc", "tfxla_device_specific_transforms.cc", "verify_tfxla_legalization.cc", @@ -333,27 +332,6 @@ cc_library( ], ) -tf_cc_test( - name = "legalize_tf_types_test", - srcs = ["legalize_tf_types_test.cc"], - deps = [ - ":test_utils", - ":xla_legalize_tf", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/core:test", - "//tensorflow/core/lib/monitoring:cell_reader", - "//tensorflow/tsl/platform:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "tf2xla_rewriter", srcs = [ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h index f6dd6f0b09985f..4b34dfd318b1cb 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h @@ -59,10 +59,6 @@ std::unique_ptr> createLegalizeTFPass( std::unique_ptr> createLegalizeTFNoFallbackPass( bool allow_partial_conversion = false); -/// Replaces types that do not exist in MHLO with equivalent types that do -/// exist. 
-std::unique_ptr> CreateLegalizeTfTypesPass(); - /// Converter to be used along with the fallback Tf2Xla patterns below. class Tf2XlaTypeConverter : public TypeConverter { public: diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index 8276f202f8e7ee..babfbd67e10b1a 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -72,24 +72,6 @@ def LegalizeTFNoFallback : Pass<"xla-legalize-tf-no-fallback", "mlir::func::Func "shape::ShapeDialect", "func::FuncDialect", "sparse_tensor::SparseTensorDialect"]; } -def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> { - let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect"; - - let description = [{ - The TF dialect uses some TF types that are illegal in the MHLO dialect and - some generic types that are legal in MHLO. This pass legalizes TF types into - types that are legal in MHLO. Rewrites here should run before TF to MHLO op - legalizations are run. - - Specifically, this pass replaces each quantized integer type with the - corresponding ordinary types. For example, `TF::Qint8Type` is replaced with - `i8` everywhere it occurs. Types that are replaced are `TF::Qint8Type`, - `TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`. - }]; - - let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()"; -} - def LegalizeTFCollective : Pass<"xla-legalize-tf-collective", "ModuleOp"> { let summary = "Legalize TF/XLA collective ops (TensorFlow dialect) to the HLO dialect"; diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 3d0a3826c851a2..2115bd02359448 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h" @@ -52,7 +53,7 @@ int main(int argc, char **argv) { // These are in compiler/mlir/tf2xla and not part of the above MHLO passes. 
mlir::mhlo::registerLegalizeTfPasses(); mlir::mhlo::registerTfXlaPasses(); - mlir::mhlo::registerLegalizeTfTypesPassPass(); + mlir::stablehlo::registerBridgePasses(); mlir::tosa::registerLegalizeTosaPasses(); mlir::tosa::registerTFtoTOSALegalizationPipeline(); mlir::tosa::registerTFLtoTOSALegalizationPipeline(); From d1983430f27e7aafaa6515fff0c2ace59d05de76 Mon Sep 17 00:00:00 2001 From: pjpratik <118897289+pjpratik@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:29:05 +0530 Subject: [PATCH 244/410] Update audio_classifier.md --- .../inference_with_metadata/task_library/audio_classifier.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md index 873e9929db4a1d..5f62b56c0fde2c 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/audio_classifier.md @@ -63,9 +63,9 @@ dependencies { // Other dependencies // Import the Audio Task Library dependency (NNAPI is included) - implementation 'org.tensorflow:tensorflow-lite-task-audio:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-task-audio:0.4.4' // Import the GPU delegate plugin Library for GPU inference - implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0' + implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4' } ``` From 04075f489e7be0de8817fe9d26182ead60b3d87d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 10:02:24 -0700 Subject: [PATCH 245/410] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 551563589 --- tensorflow/go/op/wrappers.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 82c2c91141d63c..168bc2e5a6868f 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -7147,7 +7147,7 @@ func ConvGroups(value int64) ConvAttr { } } -// Computes a N-D convolution given (N+2)-D `input` and `filter` tensors. +// Computes a N-D convolution given (N+1+batch_dims)-D `input` and (N+2)-D `filter` tensors. // // General function for computing a N-D convolution. It is required that // `1 <= N <= 3`. @@ -7173,7 +7173,7 @@ func ConvGroups(value int64) ConvAttr { // // padding: The type of padding algorithm to use. // -// Returns A 4-D tensor. The dimension order is determined by the value of +// Returns A (N+1+batch_dims)-D tensor. The dimension order is determined by the value of // `channels_last_format`, see below for details. func Conv(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...ConvAttr) (output tf.Output) { if scope.Err() != nil { From 362f2fbc72190b230638f2ae6315beaedf61524b Mon Sep 17 00:00:00 2001 From: Yishuang Pang Date: Thu, 27 Jul 2023 10:12:24 -0700 Subject: [PATCH 246/410] Allow dynamic shaped starting indices when legalizing mhlo.gather op. 
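As a sketch of what this change enables, an mhlo.gather whose start indices have a
dynamically shaped type can now be legalized to tf.GatherNd instead of being
rejected. The `?` dimensions below are assumed for illustration; the authoritative
coverage is the convert_gather_slice_dynamic_indices test added to legalize_hlo.mlir:

  // CHECK: "tf.GatherNd"(%arg0, %arg1) : (tensor<256000x1024xi8>, tensor<?x?x1xi32>) -> tensor<?x?x1024xi8>
  func.func @gather_dynamic_indices(%arg0: tensor<256000x1024xi8>,
                                    %arg1: tensor<?x?x1xi32>) -> tensor<?x?x1024xi8> {
    %0 = "mhlo.gather"(%arg0, %arg1) {
      dimension_numbers = #mhlo.gather<
        offset_dims = [2],
        collapsed_slice_dims = [0],
        start_index_map = [0],
        index_vector_dim = 2
      >,
      slice_sizes = dense<[1, 1024]> : tensor<2xi64>
    } : (tensor<256000x1024xi8>, tensor<?x?x1xi32>) -> tensor<?x?x1024xi8>
    func.return %0 : tensor<?x?x1024xi8>
  }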
PiperOrigin-RevId: 551567029 --- .../mlir/tensorflow/tests/legalize_hlo.mlir | 39 ++++++++++++++++++- .../tensorflow/transforms/legalize_hlo.cc | 10 +++-- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 2b11d4f77e2860..ed5fad4c67beea 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2769,6 +2769,43 @@ func.func @convert_gather_to_slice_batch_size_1(%arg0: tensor<1x2944xi32>, %arg1 func.return %0 : tensor<1x1504xi32> } +// CHECK-LABEL: func @convert_gather_slice_dynamic_indices( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<256000x1024xi8>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) : (tensor<256000x1024xi8>, tensor) -> tensor +// CHECK: return %[[VAL_0]] : tensor +// CHECK: } +func.func @convert_gather_slice_dynamic_indices(%arg0: tensor<256000x1024xi8>, %arg1: tensor) -> tensor { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 2 + >, + slice_sizes = dense<[1, 1024]> : tensor<2xi64> + } : (tensor<256000x1024xi8>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @convert_gather_scalar_dynamic_indices( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<256000xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_0:.*]] = "tf.GatherNd"(%[[ARG_0]], %[[ARG_1]]) : (tensor<256000xf32>, tensor) -> tensor +// CHECK: return %[[VAL_0]] : tensor +// CHECK: } +func.func @convert_gather_scalar_dynamic_indices(%arg0: tensor<256000xf32>, %arg1: tensor) -> tensor { + %0 = "mhlo.gather"(%arg0, %arg1) { + dimension_numbers = #mhlo.gather< + collapsed_slice_dims = [0], + start_index_map = [0], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<1xi64> + } : (tensor<256000xf32>, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: func @convert_gather_to_slice( // CHECK-SAME: %[[ARG_0:.*]]: tensor<3x2944xi32>, // CHECK-SAME: %[[ARG_1:.*]]: tensor<3x2xi32>) @@ -2812,7 +2849,7 @@ func.func @convert_gather_to_slice(%arg0: tensor<3x2944xi32>, %arg1: tensor<3x2x // CHECK-LABEL: func @convert_gather_to_slice_dynamic_error func.func @convert_gather_to_slice_dynamic_error(%arg0: tensor<3x?xi32>, %arg1: tensor<3x2xi32>) -> tensor<3x1504xi32> { - // expected-error @+1 {{Dynamic shaped inputs are not supported.}} + // expected-error @+1 {{Dynamic shaped operand is not supported.}} %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< offset_dims = [1], diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 9cee8917a8b60f..b577f603f7085b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -2984,8 +2984,8 @@ class ConvertGatherOp : public OpConversionPattern { ShapedType operand_type = operand.getType().cast(); ShapedType start_indices_type = start_indices.getType().cast(); ShapedType result_type = gather_op.getResult().getType().cast(); - if (!operand_type.hasStaticShape() || - !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { + if (!operand_type.hasStaticShape()) { + gather_op.emitOpError() << "Dynamic shaped operand is not 
supported."; return failure(); } @@ -3113,8 +3113,10 @@ class ConvertGatherOp : public OpConversionPattern { ShapedType result_type = gather_op.getResult().getType().cast(); if (!operand_type.hasStaticShape() || !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { - gather_op.emitOpError() << "Dynamic shaped inputs are not supported."; - return failure(); + return rewriter.notifyMatchFailure( + gather_op, + "Dynamic shaped inputs are not supported when legalizing mhlo.gather " + "op to tf.slice."); } auto start_index_map = gather_op.getDimensionNumbers().getStartIndexMap(); From 97a878f7d283e264618708c2fd67c127574661fb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 10:15:44 -0700 Subject: [PATCH 247/410] Move the ComplexType handling into the type lowering pass for collectives. PiperOrigin-RevId: 551568154 --- .../mlir/dtensor_collective_type_lowering.cc | 152 ++++++++++++++- .../dtensor_collective_type_lowering.mlir | 52 +++++ .../dtensor/mlir/utils/collective_lowering.cc | 178 +----------------- 3 files changed, 205 insertions(+), 177 deletions(-) diff --git a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc index d49c397237414d..9c3171cbf0131d 100644 --- a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc +++ b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc @@ -152,7 +152,118 @@ mlir::LogicalResult ConvertShortIntReduce(ReduceOpType reduce_op) { return mlir::success(); } -// A Walk that allows mutatatoin inside parent. +// Complex for AllReduce and ReduceScatter +template +mlir::LogicalResult ConvertComplexReduce(ReduceOpType reduce_op) { + ReduceOpType real_reduce_op; + ReduceOpType imag_reduce_op; + mlir::OpBuilder builder(reduce_op); + StatusOr output_layout = ExtractRequiredSingleLayoutFromOp(reduce_op); + if (!output_layout.ok()) { + return reduce_op.emitOpError(output_layout.status().message()); + } + + const mlir::Value tensor_input = reduce_op.getInput(); + const mlir::Value tensor_result = reduce_op.getResult(); + const mlir::TensorType complex_input_tensor_type = + tensor_input.getType().dyn_cast(); + if (!complex_input_tensor_type) { + return mlir::success(); + } + const mlir::TensorType complex_result_tensor_type = + tensor_result.getType().dyn_cast(); + if (!complex_result_tensor_type) { + return mlir::success(); + } + auto input_element_type = mlir::dyn_cast( + complex_input_tensor_type.getElementType()); + if (!input_element_type) { + return mlir::success(); + } + auto real_input_tensor_type = + mlir::RankedTensorType::get(complex_input_tensor_type.getShape(), + input_element_type.getElementType()); + auto real_result_tensor_type = + mlir::RankedTensorType::get(complex_result_tensor_type.getShape(), + input_element_type.getElementType()); + const mlir::Value tensor_temp_real = builder.create( + reduce_op.getLoc(), real_input_tensor_type, tensor_input); + const mlir::Value tensor_temp_imag = builder.create( + reduce_op.getLoc(), real_input_tensor_type, tensor_input); + real_reduce_op = mlir::dyn_cast(builder.clone(*reduce_op)); + real_reduce_op->setOperand(0, tensor_temp_real); + real_reduce_op->getResult(0).setType(real_result_tensor_type); + imag_reduce_op = mlir::dyn_cast(builder.clone(*reduce_op)); + imag_reduce_op->setOperand(0, tensor_temp_imag); + imag_reduce_op->getResult(0).setType(real_result_tensor_type); + const mlir::Type output_type = reduce_op.getResult().getType(); + auto complex_reduce_op = builder.create( + 
+      reduce_op->getLoc(), output_type, real_reduce_op.getResult(),
+      imag_reduce_op.getResult());
+  StatusOr<Layout> desired_layout =
+      ExtractRequiredSingleLayoutFromOp(reduce_op);
+  SetSingleLayoutOnOp(complex_reduce_op, *desired_layout);
+  reduce_op.getOutput().replaceAllUsesWith(complex_reduce_op.getResult());
+  reduce_op.erase();
+  return mlir::success();
+}
+
+// Complex for AllToAll, AllGather, and AllScatter
+template <typename CollectiveType>
+mlir::LogicalResult ConvertComplexCollectives(CollectiveType op) {
+  CollectiveType real_op;
+  CollectiveType imag_op;
+  mlir::OpBuilder builder(op);
+  StatusOr<Layout> output_layout = ExtractRequiredSingleLayoutFromOp(op);
+  if (!output_layout.ok()) {
+    return op.emitOpError(output_layout.status().message());
+  }
+
+  const mlir::Value tensor_input = op.getInput();
+  const mlir::Value tensor_result = op.getResult();
+  const mlir::TensorType complex_input_tensor_type =
+      tensor_input.getType().dyn_cast<mlir::TensorType>();
+  if (!complex_input_tensor_type) {
+    return mlir::success();
+  }
+  const mlir::TensorType& complex_result_tensor_type =
+      tensor_result.getType().dyn_cast<mlir::TensorType>();
+  if (!complex_result_tensor_type) {
+    return mlir::success();
+  }
+
+  auto input_element_type = mlir::dyn_cast<mlir::ComplexType>(
+      complex_input_tensor_type.getElementType());
+  if (!input_element_type) {
+    return mlir::success();
+  }
+  auto real_input_tensor_type =
+      mlir::RankedTensorType::get(complex_input_tensor_type.getShape(),
+                                  input_element_type.getElementType());
+  auto real_result_tensor_type =
+      mlir::RankedTensorType::get(complex_result_tensor_type.getShape(),
+                                  input_element_type.getElementType());
+  const mlir::Value tensor_temp_real = builder.create<mlir::TF::RealOp>(
+      op.getLoc(), real_input_tensor_type, tensor_input);
+  const mlir::Value tensor_temp_imag = builder.create<mlir::TF::ImagOp>(
+      op.getLoc(), real_input_tensor_type, tensor_input);
+  real_op = mlir::dyn_cast<CollectiveType>(builder.clone(*op));
+  real_op->setOperand(0, tensor_temp_real);
+  real_op->getResult(0).setType(real_result_tensor_type);
+  imag_op = mlir::dyn_cast<CollectiveType>(builder.clone(*op));
+  imag_op->setOperand(0, tensor_temp_imag);
+  imag_op->getResult(0).setType(real_result_tensor_type);
+  const mlir::Type output_type = op.getResult().getType();
+  auto complex_op = builder.create<mlir::TF::ComplexOp>(
+      op.getLoc(), output_type, real_op.getResult(), imag_op.getResult());
+  const Layout desired_layout = op.getOutputLayout();
+  SetSingleLayoutOnOp(complex_op, desired_layout);
+  op.getOutput().replaceAllUsesWith(complex_op.getResult());
+  op.erase();
+  return mlir::success();
+}
+
+// A Walk that allows mutation inside parent.
 template >
 mlir::LogicalResult MutatingWalk(mlir::Operation* parent, FuncT func) {
   llvm::SmallVector ops;
@@ -172,6 +283,45 @@ class DTensorCollectiveTypeLoweringPass
   void runOnOperation() override {
     mlir::func::FuncOp func = getOperation();
+    if (mlir::failed(
+            MutatingWalk(func, [&](mlir::TF::DTensorAllReduceOp all_reduce) {
+              // Lower complex type all reduce
+              return ConvertComplexReduce(all_reduce);
+            }))) {
+      signalPassFailure();
+    }
+
+    if (mlir::failed(
+            MutatingWalk(func, [&](mlir::TF::DTensorAllScatterOp all_scatter) {
+              // Lower complex type all scatter
+              return ConvertComplexCollectives(all_scatter);
+            }))) {
+      signalPassFailure();
+    }
+
+    if (mlir::failed(
+            MutatingWalk(func, [&](mlir::TF::DTensorAllGatherOp all_gather) {
+              // Lower complex type all gather.
+              return ConvertComplexCollectives(all_gather);
+            }))) {
+      signalPassFailure();
+    }
+
+    if (mlir::failed(
+            MutatingWalk(func, [&](mlir::TF::DTensorAllToAllOp all_to_all) {
+              // Lower complex type all to all
+              return ConvertComplexCollectives(all_to_all);
+            }))) {
+      signalPassFailure();
+    }
+
+    if (mlir::failed(MutatingWalk(
+            func, [&](mlir::TF::DTensorReduceScatterOp reduce_scatter) {
+              // Lower complex type reduce scatter.
+              return ConvertComplexReduce(reduce_scatter);
+            }))) {
+      signalPassFailure();
+    }
     if (mlir::failed(
             MutatingWalk(func, [&](mlir::TF::DTensorAllReduceOp all_reduce) {
diff --git a/tensorflow/dtensor/mlir/tests/dtensor_collective_type_lowering.mlir b/tensorflow/dtensor/mlir/tests/dtensor_collective_type_lowering.mlir
index b4a6f9a4904d0e..d987100b2acb67 100644
--- a/tensorflow/dtensor/mlir/tests/dtensor_collective_type_lowering.mlir
+++ b/tensorflow/dtensor/mlir/tests/dtensor_collective_type_lowering.mlir
@@ -1,5 +1,57 @@
 // RUN: dtensor-opt -split-input-file -dtensor-collective-type-lowering -verify-diagnostics %s| FileCheck %s --dump-input=fail
+// Check the lowering of AllGather on CPU with a complex element type.
+// CHECK-LABEL: func @lower_allgather_complex64
+func.func @lower_allgather_complex64(%arg0: tensor,
+    %arg1: tensor<1x2xcomplex> {tf._layout = "sharding_specs:x,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) {
+  // CHECK: "tf_device.cluster"
+  // CHECK-NEXT: %[[REAL:.*]] = "tf.Real"(%arg1)
+  // CHECK-NEXT: %[[IMAG:.*]] = "tf.Imag"(%arg1)
+  // CHECK-NEXT: %[[ALLGATHER_OUT_REAL:.*]] = "tf.DTensorAllGather"(%[[REAL]])
+  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"]
+  // CHECK-SAME: (tensor<1x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: %[[ALLGATHER_OUT_IMAG:.*]] = "tf.DTensorAllGather"(%[[IMAG]])
+  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"]
+  // CHECK-SAME: (tensor<1x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: %[[OUTPUT:.*]] = "tf.Complex"(%[[ALLGATHER_OUT_REAL]], %[[ALLGATHER_OUT_IMAG]])
+  // CHECK-SAME: _layout = ["sharding_specs:unsharded,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"]
+  // CHECK-SAME: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex>
+  // CHECK-NEXT: return %[[OUTPUT]]
+  %0 = "tf_device.cluster"() ({
+    %1 = "tf.DTensorAllGather"(%arg1) {_layout = ["sharding_specs:unsharded,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], input_layout = #dtensor.layout, output_layout = #dtensor.layout} : (tensor<1x2xcomplex>) -> tensor<2x2xcomplex>
+    tf_device.return %1 : tensor<2x2xcomplex>
+  }) {_mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x2xcomplex>
+  func.return
+}
+
+// -----
+
+// Check the lowering of DTensorAllToAll on CPU with a complex element type.
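+// As in the AllGather test above, the expectation (per the pass added in this
+// patch) is that the complex collective is rewritten element-wise: tf.Real and
+// tf.Imag split the operand, one DTensorAllToAll runs per component, and
+// tf.Complex recombines the results, which is what the CHECK lines below verify.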
+// CHECK-LABEL: func @lower_all_to_all_complex128 +func.func @lower_all_to_all_complex128(%arg0: tensor, + %arg1: tensor<1x2xcomplex> {tf._layout = "sharding_specs:x,unsharded, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) { + // CHECK: "tf_device.cluster" + // CHECK-NEXT: %[[REAL:.*]] = "tf.Real"(%arg1) + // CHECK-NEXT: %[[IMAG:.*]] = "tf.Imag"(%arg1) + // CHECK-NEXT: %[[ALLTOALL_OUT_REAL:.*]] = "tf.DTensorAllToAll"(%[[REAL]]) + // CHECK-SAME: _layout = ["sharding_specs:unsharded,x, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<1x2xf64>) -> tensor<2x1xf64> + // CHECK-NEXT: %[[ALLTOALL_OUT_IMAG:.*]] = "tf.DTensorAllToAll"(%[[IMAG]]) + // CHECK-SAME: _layout = ["sharding_specs:unsharded,x, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<1x2xf64>) -> tensor<2x1xf64> + // CHECK-NEXT: %[[OUTPUT:.*]] = "tf.Complex"(%[[ALLTOALL_OUT_REAL]], %[[ALLTOALL_OUT_IMAG]]) + // CHECK-SAME: _layout = ["sharding_specs:unsharded,x, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x1xf64>, tensor<2x1xf64>) -> tensor<2x1xcomplex> + // CHECK-NEXT return %[[OUTPUT]] + %0 = "tf_device.cluster"() ({ + %1 = "tf.DTensorAllToAll"(%arg1) {_layout = ["sharding_specs:unsharded,x, mesh:|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"], input_layout = #dtensor.layout, output_layout = #dtensor.layout} : (tensor<1x2xcomplex>) -> tensor<2x1xcomplex> + tf_device.return %1 : tensor<2x1xcomplex> + }) {_mesh = "|x=2,y=1|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x1xcomplex> + func.return +} + +// ----- + // Check the lowering of AllReduce on TPU with any boolean reduction. 
// CHECK-LABEL: func @lower_allreduce_any_boolean func.func @lower_allreduce_any_boolean() -> (tensor<4096x8192xi1>) { diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 61897371071e2f..2897e9c6b80da6 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -551,151 +551,8 @@ mlir::LogicalResult LowerAllReduceOpImpl( return mlir::success(); } -// Extension for all reduce with complex numbers -template -mlir::LogicalResult ConvertComplexAllReduce( - AllReduceOpType all_reduce_op, AllReduceOpType* real_all_reduce_op, - AllReduceOpType* imag_all_reduce_op) { - mlir::OpBuilder builder(all_reduce_op); - StatusOr output_layout = - ExtractRequiredSingleLayoutFromOp(all_reduce_op); - if (!output_layout.ok()) { - return all_reduce_op.emitOpError(output_layout.status().message()); - } - - const mlir::Value tensor_input = all_reduce_op.getInput(); - const mlir::Value tensor_result = all_reduce_op.getResult(); - const mlir::TensorType complex_input_tensor_type = - tensor_input.getType().dyn_cast(); - if (!complex_input_tensor_type) { - return mlir::success(); - } - const mlir::TensorType complex_result_tensor_type = - tensor_result.getType().dyn_cast(); - if (!complex_result_tensor_type) { - return mlir::success(); - } - auto input_element_type = mlir::dyn_cast( - complex_input_tensor_type.getElementType()); - if (!input_element_type) { - return mlir::success(); - } - auto real_input_tensor_type = - mlir::RankedTensorType::get(complex_input_tensor_type.getShape(), - input_element_type.getElementType()); - auto real_result_tensor_type = - mlir::RankedTensorType::get(complex_result_tensor_type.getShape(), - input_element_type.getElementType()); - const mlir::Value tensor_temp_real = builder.create( - all_reduce_op.getLoc(), real_input_tensor_type, tensor_input); - const mlir::Value tensor_temp_imag = builder.create( - all_reduce_op.getLoc(), real_input_tensor_type, tensor_input); - *real_all_reduce_op = - mlir::dyn_cast(builder.clone(*all_reduce_op)); - real_all_reduce_op->setOperand(0, tensor_temp_real); - (*real_all_reduce_op)->getResult(0).setType(real_result_tensor_type); - *imag_all_reduce_op = - mlir::dyn_cast(builder.clone(*all_reduce_op)); - imag_all_reduce_op->setOperand(0, tensor_temp_imag); - (*imag_all_reduce_op)->getResult(0).setType(real_result_tensor_type); - const mlir::Type output_type = all_reduce_op.getResult().getType(); - auto complex_all_reduce_op = builder.create( - all_reduce_op->getLoc(), output_type, real_all_reduce_op->getResult(), - imag_all_reduce_op->getResult()); - StatusOr desired_layout = - ExtractRequiredSingleLayoutFromOp(all_reduce_op); - SetSingleLayoutOnOp(complex_all_reduce_op, *desired_layout); - all_reduce_op.getOutput().replaceAllUsesWith( - complex_all_reduce_op.getResult()); - all_reduce_op.erase(); - return mlir::success(); -} - -// For AllToAll and AllScatter -template -mlir::LogicalResult ConvertComplexAllToAll(AllToAllOpType all_to_all_op, - AllToAllOpType* real_all_to_all, - AllToAllOpType* imag_all_to_all) { - mlir::OpBuilder builder(all_to_all_op); - StatusOr output_layout = - ExtractRequiredSingleLayoutFromOp(all_to_all_op); - if (!output_layout.ok()) { - return all_to_all_op.emitOpError(output_layout.status().message()); - } - - const mlir::Value tensor_input = all_to_all_op.getInput(); - const mlir::Value tensor_result = all_to_all_op.getResult(); - const mlir::TensorType complex_input_tensor_type = - 
tensor_input.getType().dyn_cast(); - if (!complex_input_tensor_type) { - return mlir::success(); - } - const mlir::TensorType& complex_result_tensor_type = - tensor_result.getType().dyn_cast(); - if (!complex_result_tensor_type) { - return mlir::success(); - } - - auto input_element_type = mlir::dyn_cast( - complex_input_tensor_type.getElementType()); - if (!input_element_type) { - return mlir::success(); - } - auto real_input_tensor_type = - mlir::RankedTensorType::get(complex_input_tensor_type.getShape(), - input_element_type.getElementType()); - auto real_result_tensor_type = - mlir::RankedTensorType::get(complex_result_tensor_type.getShape(), - input_element_type.getElementType()); - const mlir::Value tensor_temp_real = builder.create( - all_to_all_op.getLoc(), real_input_tensor_type, tensor_input); - const mlir::Value tensor_temp_imag = builder.create( - all_to_all_op.getLoc(), real_input_tensor_type, tensor_input); - *real_all_to_all = - mlir::dyn_cast(builder.clone(*all_to_all_op)); - (*real_all_to_all)->setOperand(0, tensor_temp_real); - (*real_all_to_all)->getResult(0).setType(real_result_tensor_type); - *imag_all_to_all = - mlir::dyn_cast(builder.clone(*all_to_all_op)); - (*imag_all_to_all)->setOperand(0, tensor_temp_imag); - (*imag_all_to_all)->getResult(0).setType(real_result_tensor_type); - const mlir::Type output_type = all_to_all_op.getResult().getType(); - auto all_to_all_complex_op = builder.create( - all_to_all_op.getLoc(), output_type, real_all_to_all->getResult(), - imag_all_to_all->getResult()); - const Layout desired_layout = all_to_all_op.getOutputLayout(); - SetSingleLayoutOnOp(all_to_all_complex_op, desired_layout); - all_to_all_op.getOutput().replaceAllUsesWith( - all_to_all_complex_op.getResult()); - all_to_all_op.erase(); - return mlir::success(); -} - mlir::LogicalResult LowerAllReduceOp(mlir::MLIRContext& context, mlir::TF::DTensorAllReduceOp all_reduce) { - mlir::TF::DTensorAllReduceOp real_all_reduce; - mlir::TF::DTensorAllReduceOp imag_all_reduce; - if (mlir::failed(ConvertComplexAllReduce( - all_reduce, &real_all_reduce, &imag_all_reduce))) - return mlir::failure(); - if (real_all_reduce && imag_all_reduce) { - mlir::OpBuilder builder_real(real_all_reduce); - mlir::OpBuilder builder_imag(imag_all_reduce); - mlir::Value result_real; - mlir::Value result_imag; - if (mlir::failed(LowerAllReduceOpImpl(context, builder_real, - real_all_reduce, &result_real))) - return mlir::failure(); - if (mlir::failed(LowerAllReduceOpImpl(context, builder_imag, - imag_all_reduce, &result_imag))) - return mlir::failure(); - - real_all_reduce.replaceAllUsesWith(result_real); - imag_all_reduce.replaceAllUsesWith(result_imag); - real_all_reduce.erase(); - imag_all_reduce.erase(); - return mlir::success(); - } mlir::OpBuilder builder(all_reduce); mlir::Value result; if (mlir::failed(LowerAllReduceOpImpl(context, builder, all_reduce, &result))) @@ -1201,7 +1058,7 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) { return mlir::LogicalResult::success(); } -mlir::LogicalResult LowerAllScatterHelper( +mlir::LogicalResult LowerAllScatterOp( mlir::TF::DTensorAllScatterOp all_scatter) { const Layout original_layout = all_scatter.getInputLayout(); const Layout desired_layout = all_scatter.getOutputLayout(); @@ -1302,24 +1159,7 @@ mlir::LogicalResult LowerAllScatterHelper( return mlir::LogicalResult::success(); } -mlir::LogicalResult LowerAllScatterOp( - mlir::TF::DTensorAllScatterOp all_scatter) { - mlir::TF::DTensorAllScatterOp real_all_scatter; - 
mlir::TF::DTensorAllScatterOp imag_all_scatter; - if (mlir::failed(ConvertComplexAllToAll( - all_scatter, &real_all_scatter, &imag_all_scatter))) - return mlir::failure(); - - if (real_all_scatter && imag_all_scatter) { - auto status = LowerAllScatterHelper(real_all_scatter); - status = LowerAllScatterHelper(imag_all_scatter); - return status; - } - return LowerAllScatterHelper(all_scatter); -} - -mlir::LogicalResult LowerAllToAllHelper( - mlir::TF::DTensorAllToAllOp all_to_all) { +mlir::LogicalResult LowerAllToAllOp(mlir::TF::DTensorAllToAllOp all_to_all) { mlir::OpBuilder builder(all_to_all); mlir::Location loc = all_to_all.getLoc(); const Layout src_layout = all_to_all.getInputLayout(); @@ -1390,20 +1230,6 @@ mlir::LogicalResult LowerAllToAllHelper( return mlir::LogicalResult::success(); } -mlir::LogicalResult LowerAllToAllOp(mlir::TF::DTensorAllToAllOp all_to_all) { - mlir::TF::DTensorAllToAllOp real_all_to_all; - mlir::TF::DTensorAllToAllOp imag_all_to_all; - if (mlir::failed(ConvertComplexAllToAll( - all_to_all, &real_all_to_all, &imag_all_to_all))) - return mlir::failure(); - if (real_all_to_all && imag_all_to_all) { - auto status = LowerAllToAllHelper(real_all_to_all); - status = LowerAllToAllHelper(imag_all_to_all); - return status; - } - return LowerAllToAllHelper(all_to_all); -} - } // namespace internal namespace { From 9c91df1a28e574a383ba571c9d2a1c99995f4dcb Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 27 Jul 2023 10:19:28 -0700 Subject: [PATCH 248/410] Update legacy reference to tensor.Tensor. PiperOrigin-RevId: 551569275 --- tensorflow/python/framework/framework_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index 52bad75b06e813..34aa3435add619 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -21,7 +21,7 @@ from tensorflow.python.framework.indexed_slices import IndexedSlices from tensorflow.python.framework.ops import Graph from tensorflow.python.framework.ops import Operation -from tensorflow.python.framework.ops import Tensor +from tensorflow.python.framework.tensor import Tensor from tensorflow.python.framework.sparse_tensor import SparseTensor from tensorflow.python.framework.sparse_tensor import SparseTensorValue From 5368a3ad4719842eecae9e4c965b1a5d072cb9e5 Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 27 Jul 2023 10:20:12 -0700 Subject: [PATCH 249/410] Delete all API exports from the stale tensorflow/python/keras/ directory. 
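For context on the mechanical change that follows: a `keras_export` decorator only registers a symbol under the public `keras.*` API namespace, so deleting the decorators (and the now-unused `keras_export` imports) leaves every function and class definition in place while stopping this stale copy of Keras from claiming public API names. A minimal sketch of the pattern being removed, using `keras.activations.linear`, which appears in the diff below (the snippet is illustrative, not an excerpt):

    # Only the export registration is deleted; the definition itself stays.
    from tensorflow.python.util.tf_export import keras_export

    @keras_export('keras.activations.linear')  # this decorator line is what gets removed
    def linear(x):
      """Linear activation function (pass-through)."""
      return x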
PiperOrigin-RevId: 551569501 --- tensorflow/python/keras/__init__.py | 6 - tensorflow/python/keras/activations.py | 17 -- tensorflow/python/keras/backend.py | 148 ------------------ tensorflow/python/keras/backend_config.py | 7 - tensorflow/python/keras/callbacks.py | 16 -- tensorflow/python/keras/callbacks_v1.py | 2 - tensorflow/python/keras/constraints.py | 11 -- tensorflow/python/keras/engine/base_layer.py | 2 - .../python/keras/engine/base_layer_utils.py | 3 - .../keras/engine/base_preprocessing_layer.py | 2 - .../python/keras/engine/data_adapter.py | 3 - tensorflow/python/keras/engine/input_layer.py | 3 - tensorflow/python/keras/engine/input_spec.py | 2 - tensorflow/python/keras/engine/sequential.py | 2 - tensorflow/python/keras/engine/training.py | 2 - .../python/keras/initializers/__init__.py | 4 - .../keras/initializers/initializers_v1.py | 32 +--- .../keras/initializers/initializers_v2.py | 43 ----- .../keras/layers/advanced_activations.py | 7 - .../python/keras/layers/convolutional.py | 24 --- .../keras/layers/convolutional_recurrent.py | 2 - tensorflow/python/keras/layers/core.py | 14 -- .../python/keras/layers/dense_attention.py | 3 - tensorflow/python/keras/layers/embeddings.py | 2 - .../keras/layers/legacy_rnn/rnn_cell_impl.py | 11 -- tensorflow/python/keras/layers/merge.py | 17 -- tensorflow/python/keras/layers/pooling.py | 16 -- tensorflow/python/keras/layers/recurrent.py | 11 -- .../python/keras/layers/serialization.py | 3 - .../python/keras/legacy_tf_layers/base.py | 6 - .../keras/legacy_tf_layers/convolutional.py | 15 -- .../python/keras/legacy_tf_layers/core.py | 7 - .../python/keras/legacy_tf_layers/pooling.py | 13 -- tensorflow/python/keras/losses.py | 61 -------- tensorflow/python/keras/metrics.py | 48 ------ .../keras/mixed_precision/get_layer_policy.py | 2 - .../mixed_precision/loss_scale_optimizer.py | 3 - .../python/keras/mixed_precision/policy.py | 7 - tensorflow/python/keras/models.py | 6 - .../python/keras/optimizer_v2/adadelta.py | 2 - .../python/keras/optimizer_v2/adagrad.py | 2 - tensorflow/python/keras/optimizer_v2/adam.py | 2 - .../python/keras/optimizer_v2/adamax.py | 2 - tensorflow/python/keras/optimizer_v2/ftrl.py | 2 - .../keras/optimizer_v2/gradient_descent.py | 2 - .../optimizer_v2/learning_rate_schedule.py | 12 -- tensorflow/python/keras/optimizer_v2/nadam.py | 2 - .../python/keras/optimizer_v2/optimizer_v2.py | 2 - .../python/keras/optimizer_v2/rmsprop.py | 2 - tensorflow/python/keras/optimizers.py | 4 - tensorflow/python/keras/regularizers.py | 9 -- .../python/keras/saving/model_config.py | 4 - tensorflow/python/keras/saving/save.py | 3 - .../keras/saving/saved_model_experimental.py | 3 - tensorflow/python/keras/utils/data_utils.py | 6 - .../python/keras/utils/dataset_creator.py | 2 - .../python/keras/utils/generic_utils.py | 10 -- tensorflow/python/keras/utils/layer_utils.py | 2 - tensorflow/python/keras/utils/losses_utils.py | 3 - tensorflow/python/keras/utils/np_utils.py | 3 - tensorflow/python/keras/utils/tf_utils.py | 2 - tensorflow/python/keras/utils/vis_utils.py | 3 - 62 files changed, 1 insertion(+), 666 deletions(-) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 1c8c951a7fa761..a3a8a380d7e08a 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -27,9 +27,3 @@ from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.sequential import Sequential from tensorflow.python.keras.engine.training import Model - -from 
tensorflow.python.util.tf_export import keras_export - -__version__ = '2.6.0' - -keras_export('keras.__version__').export_constant(__name__, '__version__') diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index 9afb77323d9e20..623c8e365cb9ef 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -21,7 +21,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.util import dispatch -from tensorflow.python.util.tf_export import keras_export # b/123041942 # In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras @@ -36,7 +35,6 @@ } -@keras_export('keras.activations.softmax') @dispatch.add_dispatch_support def softmax(x, axis=-1): """Softmax converts a vector of values to a probability distribution. @@ -93,7 +91,6 @@ def softmax(x, axis=-1): return output -@keras_export('keras.activations.elu') @dispatch.add_dispatch_support def elu(x, alpha=1.0): """Exponential Linear Unit. @@ -143,7 +140,6 @@ def elu(x, alpha=1.0): return backend.elu(x, alpha) -@keras_export('keras.activations.selu') @dispatch.add_dispatch_support def selu(x): """Scaled Exponential Linear Unit (SELU). @@ -197,7 +193,6 @@ def selu(x): return nn.selu(x) -@keras_export('keras.activations.softplus') @dispatch.add_dispatch_support def softplus(x): """Softplus activation function, `softplus(x) = log(exp(x) + 1)`. @@ -219,7 +214,6 @@ def softplus(x): return math_ops.softplus(x) -@keras_export('keras.activations.softsign') @dispatch.add_dispatch_support def softsign(x): """Softsign activation function, `softsign(x) = x / (abs(x) + 1)`. @@ -240,7 +234,6 @@ def softsign(x): return nn.softsign(x) -@keras_export('keras.activations.swish') @dispatch.add_dispatch_support def swish(x): """Swish activation function, `swish(x) = x * sigmoid(x)`. @@ -271,7 +264,6 @@ def swish(x): return nn.swish(x) -@keras_export('keras.activations.relu') @dispatch.add_dispatch_support def relu(x, alpha=0., max_value=None, threshold=0): """Applies the rectified linear unit activation function. @@ -312,7 +304,6 @@ def relu(x, alpha=0., max_value=None, threshold=0): return backend.relu(x, alpha=alpha, max_value=max_value, threshold=threshold) -@keras_export('keras.activations.gelu', v1=[]) @dispatch.add_dispatch_support def gelu(x, approximate=False): """Applies the Gaussian error linear unit (GELU) activation function. @@ -352,7 +343,6 @@ def gelu(x, approximate=False): return nn.gelu(x, approximate) -@keras_export('keras.activations.tanh') @dispatch.add_dispatch_support def tanh(x): """Hyperbolic tangent activation function. @@ -374,7 +364,6 @@ def tanh(x): return nn.tanh(x) -@keras_export('keras.activations.sigmoid') @dispatch.add_dispatch_support def sigmoid(x): """Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`. @@ -407,7 +396,6 @@ def sigmoid(x): return output -@keras_export('keras.activations.exponential') @dispatch.add_dispatch_support def exponential(x): """Exponential activation function. @@ -428,7 +416,6 @@ def exponential(x): return math_ops.exp(x) -@keras_export('keras.activations.hard_sigmoid') @dispatch.add_dispatch_support def hard_sigmoid(x): """Hard sigmoid activation function. @@ -457,7 +444,6 @@ def hard_sigmoid(x): return backend.hard_sigmoid(x) -@keras_export('keras.activations.linear') @dispatch.add_dispatch_support def linear(x): """Linear activation function (pass-through). 
@@ -478,7 +464,6 @@ def linear(x): return x -@keras_export('keras.activations.serialize') @dispatch.add_dispatch_support def serialize(activation): """Returns the string identifier of an activation function. @@ -517,7 +502,6 @@ def serialize(activation): silu = nn.swish -@keras_export('keras.activations.deserialize') @dispatch.add_dispatch_support def deserialize(name, custom_objects=None): """Returns activation function given a string identifier. @@ -560,7 +544,6 @@ def deserialize(name, custom_objects=None): printable_module_name='activation function') -@keras_export('keras.activations.get') @dispatch.add_dispatch_support def get(identifier): """Returns function. diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 18ebf20ab0eff3..7d491e3e23e0aa 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -82,7 +82,6 @@ from tensorflow.python.training import moving_averages from tensorflow.python.util import dispatch from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls py_all = all @@ -173,7 +172,6 @@ def __init__(self): set_image_data_format = backend_config.set_image_data_format -@keras_export('keras.backend.backend') @doc_controls.do_not_generate_docs def backend(): """Publicly accessible method for determining the current backend. @@ -186,7 +184,6 @@ def backend(): return 'tensorflow' -@keras_export('keras.backend.cast_to_floatx') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def cast_to_floatx(x): @@ -220,7 +217,6 @@ def cast_to_floatx(x): return np.asarray(x, dtype=floatx()) -@keras_export('keras.backend.get_uid') def get_uid(prefix=''): """Associates a string prefix with an integer counter in a TensorFlow graph. @@ -246,7 +242,6 @@ def get_uid(prefix=''): return layer_name_uids[prefix] -@keras_export('keras.backend.reset_uids') def reset_uids(): """Resets graph identifiers. """ @@ -255,7 +250,6 @@ def reset_uids(): OBSERVED_NAMES.clear() -@keras_export('keras.backend.clear_session') def clear_session(): """Resets all state generated by Keras. @@ -320,7 +314,6 @@ def clear_session(): context.context().clear_kernel_cache() -@keras_export('keras.backend.manual_variable_initialization') @doc_controls.do_not_generate_docs def manual_variable_initialization(value): """Sets the manual variable initialization flag. @@ -338,7 +331,6 @@ def manual_variable_initialization(value): _MANUAL_VAR_INIT = value -@keras_export('keras.backend.learning_phase') @doc_controls.do_not_generate_docs def learning_phase(): """Returns the learning phase flag. @@ -407,7 +399,6 @@ def _default_learning_phase(): False, shape=(), name='keras_learning_phase') -@keras_export('keras.backend.set_learning_phase') @doc_controls.do_not_generate_docs def set_learning_phase(value): """Sets the learning phase to a fixed value. @@ -473,7 +464,6 @@ def deprecated_internal_set_learning_phase(value): _GRAPH_LEARNING_PHASES[get_graph()] = value -@keras_export('keras.backend.learning_phase_scope') @tf_contextlib.contextmanager @doc_controls.do_not_generate_docs def learning_phase_scope(value): @@ -717,7 +707,6 @@ def _get_session(op_input_list=()): return session -@keras_export(v1=['keras.backend.get_session']) def get_session(op_input_list=()): """Returns the TF session to be used by the backend. 
@@ -793,7 +782,6 @@ def _scratch_graph(graph=None): _CURRENT_SCRATCH_GRAPH.graph = None -@keras_export(v1=['keras.backend.set_session']) def set_session(session): """Sets the global TensorFlow session. @@ -946,7 +934,6 @@ def _to_tensor(x, dtype): return tensor_conversion.convert_to_tensor_v2_with_dispatch(x, dtype=dtype) -@keras_export('keras.backend.is_sparse') @doc_controls.do_not_generate_docs def is_sparse(tensor): """Returns whether a tensor is a sparse tensor. @@ -974,7 +961,6 @@ def is_sparse(tensor): return isinstance(tensor, sparse_tensor.SparseTensor) -@keras_export('keras.backend.to_dense') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def to_dense(tensor): @@ -1003,7 +989,6 @@ def to_dense(tensor): return tensor -@keras_export('keras.backend.name_scope', v1=[]) @doc_controls.do_not_generate_docs def name_scope(name): """A context manager for use when defining a Python op. @@ -1031,12 +1016,7 @@ def my_op(a): """ return ops.name_scope_v2(name) -# Export V1 version. -_v1_name_scope = ops.name_scope_v1 -keras_export(v1=['keras.backend.name_scope'])(_v1_name_scope) - -@keras_export('keras.backend.variable') @doc_controls.do_not_generate_docs def variable(value, dtype=None, name=None, constraint=None): """Instantiates a variable and returns it. @@ -1095,7 +1075,6 @@ def track_tf_optimizer(tf_optimizer): optimizers.add(tf_optimizer) -@keras_export('keras.__internal__.backend.track_variable', v1=[]) def track_variable(v): """Tracks the given variable for initialization.""" if context.executing_eagerly(): @@ -1174,7 +1153,6 @@ def _get_variables(graph=None): return variables -@keras_export('keras.__internal__.backend.initialize_variables', v1=[]) def _initialize_variables(session): """Utility to initialize uninitialized variables on the fly.""" variables = _get_variables(get_graph()) @@ -1201,7 +1179,6 @@ def _initialize_variables(session): session.run(variables_module.variables_initializer(uninitialized_vars)) -@keras_export('keras.backend.constant') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def constant(value, dtype=None, shape=None, name=None): @@ -1222,7 +1199,6 @@ def constant(value, dtype=None, shape=None, name=None): return constant_op.constant(value, dtype=dtype, shape=shape, name=name) -@keras_export('keras.backend.is_keras_tensor') def is_keras_tensor(x): """Returns whether `x` is a Keras tensor. @@ -1276,7 +1252,6 @@ def is_keras_tensor(x): return hasattr(x, '_keras_history') -@keras_export('keras.backend.placeholder') @doc_controls.do_not_generate_docs def placeholder(shape=None, ndim=None, @@ -1393,7 +1368,6 @@ def is_placeholder(x): return False -@keras_export('keras.backend.shape') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def shape(x): @@ -1419,7 +1393,6 @@ def shape(x): return array_ops.shape(x) -@keras_export('keras.backend.int_shape') @doc_controls.do_not_generate_docs def int_shape(x): """Returns the shape of tensor or variable as a tuple of int or None entries. @@ -1450,7 +1423,6 @@ def int_shape(x): return None -@keras_export('keras.backend.ndim') @doc_controls.do_not_generate_docs def ndim(x): """Returns the number of axes in a tensor, as an integer. 
@@ -1476,7 +1448,6 @@ def ndim(x): return x.shape.rank -@keras_export('keras.backend.dtype') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def dtype(x): @@ -1523,7 +1494,6 @@ def dtype_numpy(x): return dtypes_module.as_dtype(x.dtype).as_numpy_dtype -@keras_export('keras.backend.eval') @doc_controls.do_not_generate_docs def eval(x): """Evaluates the value of a variable. @@ -1546,7 +1516,6 @@ def eval(x): return get_value(to_dense(x)) -@keras_export('keras.backend.zeros') @doc_controls.do_not_generate_docs def zeros(shape, dtype=None, name=None): """Instantiates an all-zeros variable and returns it. @@ -1591,7 +1560,6 @@ def zeros(shape, dtype=None, name=None): return v -@keras_export('keras.backend.ones') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def ones(shape, dtype=None, name=None): @@ -1627,7 +1595,6 @@ def ones(shape, dtype=None, name=None): return v -@keras_export('keras.backend.eye') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def eye(size, dtype=None, name=None): @@ -1658,7 +1625,6 @@ def eye(size, dtype=None, name=None): return variable(linalg_ops.eye(size, dtype=tf_dtype), dtype, name) -@keras_export('keras.backend.zeros_like') @doc_controls.do_not_generate_docs def zeros_like(x, dtype=None, name=None): """Instantiates an all-zeros variable of the same shape as another tensor. @@ -1685,7 +1651,6 @@ def zeros_like(x, dtype=None, name=None): return array_ops.zeros_like(x, dtype=dtype, name=name) -@keras_export('keras.backend.ones_like') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def ones_like(x, dtype=None, name=None): @@ -1725,7 +1690,6 @@ def identity(x, name=None): return array_ops.identity(x, name=name) -@keras_export('keras.backend.random_uniform_variable') @doc_controls.do_not_generate_docs def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): """Instantiates a variable with values drawn from a uniform distribution. @@ -1760,7 +1724,6 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None): return variable(value, dtype=dtype, name=name) -@keras_export('keras.backend.random_normal_variable') @doc_controls.do_not_generate_docs def random_normal_variable(shape, mean, scale, dtype=None, name=None, seed=None): @@ -1796,7 +1759,6 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None, return variable(value, dtype=dtype, name=name) -@keras_export('keras.backend.count_params') @doc_controls.do_not_generate_docs def count_params(x): """Returns the static number of elements in a variable or tensor. @@ -1820,7 +1782,6 @@ def count_params(x): return np.prod(x.shape.as_list()) -@keras_export('keras.backend.cast') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def cast(x, dtype): @@ -1853,13 +1814,11 @@ def cast(x, dtype): # UPDATES OPS -@keras_export('keras.backend.update') @doc_controls.do_not_generate_docs def update(x, new_x): return state_ops.assign(x, new_x) -@keras_export('keras.backend.update_add') @doc_controls.do_not_generate_docs def update_add(x, increment): """Update the value of `x` by adding `increment`. @@ -1874,7 +1833,6 @@ def update_add(x, increment): return state_ops.assign_add(x, increment) -@keras_export('keras.backend.update_sub') @doc_controls.do_not_generate_docs def update_sub(x, decrement): """Update the value of `x` by subtracting `decrement`. 
@@ -1889,7 +1847,6 @@ def update_sub(x, decrement): return state_ops.assign_sub(x, decrement) -@keras_export('keras.backend.moving_average_update') @doc_controls.do_not_generate_docs def moving_average_update(x, value, momentum): """Compute the exponential moving average of a value. @@ -1940,7 +1897,6 @@ def moving_average_update(x, value, momentum): # LINEAR ALGEBRA -@keras_export('keras.backend.dot') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def dot(x, y): @@ -2007,7 +1963,6 @@ def dot(x, y): return out -@keras_export('keras.backend.batch_dot') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def batch_dot(x, y, axes=None): @@ -2197,7 +2152,6 @@ def batch_dot(x, y, axes=None): return result -@keras_export('keras.backend.transpose') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def transpose(x): @@ -2230,7 +2184,6 @@ def transpose(x): return array_ops.transpose(x) -@keras_export('keras.backend.gather') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def gather(reference, indices): @@ -2267,7 +2220,6 @@ def gather(reference, indices): # ELEMENT-WISE OPERATIONS -@keras_export('keras.backend.max') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def max(x, axis=None, keepdims=False): @@ -2287,7 +2239,6 @@ def max(x, axis=None, keepdims=False): return math_ops.reduce_max(x, axis, keepdims) -@keras_export('keras.backend.min') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def min(x, axis=None, keepdims=False): @@ -2307,7 +2258,6 @@ def min(x, axis=None, keepdims=False): return math_ops.reduce_min(x, axis, keepdims) -@keras_export('keras.backend.sum') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sum(x, axis=None, keepdims=False): @@ -2327,7 +2277,6 @@ def sum(x, axis=None, keepdims=False): return math_ops.reduce_sum(x, axis, keepdims) -@keras_export('keras.backend.prod') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def prod(x, axis=None, keepdims=False): @@ -2347,7 +2296,6 @@ def prod(x, axis=None, keepdims=False): return math_ops.reduce_prod(x, axis, keepdims) -@keras_export('keras.backend.cumsum') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def cumsum(x, axis=0): @@ -2363,7 +2311,6 @@ def cumsum(x, axis=0): return math_ops.cumsum(x, axis=axis) -@keras_export('keras.backend.cumprod') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def cumprod(x, axis=0): @@ -2379,7 +2326,6 @@ def cumprod(x, axis=0): return math_ops.cumprod(x, axis=axis) -@keras_export('keras.backend.var') @doc_controls.do_not_generate_docs def var(x, axis=None, keepdims=False): """Variance of a tensor, alongside the specified axis. 
@@ -2400,7 +2346,6 @@ def var(x, axis=None, keepdims=False): return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims) -@keras_export('keras.backend.std') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def std(x, axis=None, keepdims=False): @@ -2428,7 +2373,6 @@ def std(x, axis=None, keepdims=False): return math_ops.reduce_std(x, axis=axis, keepdims=keepdims) -@keras_export('keras.backend.mean') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def mean(x, axis=None, keepdims=False): @@ -2450,7 +2394,6 @@ def mean(x, axis=None, keepdims=False): return math_ops.reduce_mean(x, axis, keepdims) -@keras_export('keras.backend.any') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def any(x, axis=None, keepdims=False): @@ -2468,7 +2411,6 @@ def any(x, axis=None, keepdims=False): return math_ops.reduce_any(x, axis, keepdims) -@keras_export('keras.backend.all') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def all(x, axis=None, keepdims=False): @@ -2486,7 +2428,6 @@ def all(x, axis=None, keepdims=False): return math_ops.reduce_all(x, axis, keepdims) -@keras_export('keras.backend.argmax') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def argmax(x, axis=-1): @@ -2502,7 +2443,6 @@ def argmax(x, axis=-1): return math_ops.argmax(x, axis) -@keras_export('keras.backend.argmin') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def argmin(x, axis=-1): @@ -2518,7 +2458,6 @@ def argmin(x, axis=-1): return math_ops.argmin(x, axis) -@keras_export('keras.backend.square') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def square(x): @@ -2533,7 +2472,6 @@ def square(x): return math_ops.square(x) -@keras_export('keras.backend.abs') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def abs(x): @@ -2548,7 +2486,6 @@ def abs(x): return math_ops.abs(x) -@keras_export('keras.backend.sqrt') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sqrt(x): @@ -2568,7 +2505,6 @@ def sqrt(x): return math_ops.sqrt(x) -@keras_export('keras.backend.exp') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def exp(x): @@ -2583,7 +2519,6 @@ def exp(x): return math_ops.exp(x) -@keras_export('keras.backend.log') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def log(x): @@ -2619,7 +2554,6 @@ def logsumexp(x, axis=None, keepdims=False): return math_ops.reduce_logsumexp(x, axis, keepdims) -@keras_export('keras.backend.round') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def round(x): @@ -2636,7 +2570,6 @@ def round(x): return math_ops.round(x) -@keras_export('keras.backend.sign') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sign(x): @@ -2651,7 +2584,6 @@ def sign(x): return math_ops.sign(x) -@keras_export('keras.backend.pow') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def pow(x, a): @@ -2667,7 +2599,6 @@ def pow(x, a): return math_ops.pow(x, a) -@keras_export('keras.backend.clip') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def clip(x, min_value, max_value): @@ -2692,7 +2623,6 @@ def clip(x, min_value, max_value): return clip_ops.clip_by_value(x, min_value, max_value) -@keras_export('keras.backend.equal') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def equal(x, y): @@ -2708,7 +2638,6 @@ def equal(x, y): return math_ops.equal(x, y) -@keras_export('keras.backend.not_equal') @dispatch.add_dispatch_support 
@doc_controls.do_not_generate_docs def not_equal(x, y): @@ -2724,7 +2653,6 @@ def not_equal(x, y): return math_ops.not_equal(x, y) -@keras_export('keras.backend.greater') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def greater(x, y): @@ -2740,7 +2668,6 @@ def greater(x, y): return math_ops.greater(x, y) -@keras_export('keras.backend.greater_equal') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def greater_equal(x, y): @@ -2756,7 +2683,6 @@ def greater_equal(x, y): return math_ops.greater_equal(x, y) -@keras_export('keras.backend.less') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def less(x, y): @@ -2772,7 +2698,6 @@ def less(x, y): return math_ops.less(x, y) -@keras_export('keras.backend.less_equal') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def less_equal(x, y): @@ -2788,7 +2713,6 @@ def less_equal(x, y): return math_ops.less_equal(x, y) -@keras_export('keras.backend.maximum') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def maximum(x, y): @@ -2814,7 +2738,6 @@ def maximum(x, y): return math_ops.maximum(x, y) -@keras_export('keras.backend.minimum') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def minimum(x, y): @@ -2830,7 +2753,6 @@ def minimum(x, y): return math_ops.minimum(x, y) -@keras_export('keras.backend.sin') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sin(x): @@ -2845,7 +2767,6 @@ def sin(x): return math_ops.sin(x) -@keras_export('keras.backend.cos') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def cos(x): @@ -2962,7 +2883,6 @@ def _fused_normalize_batch_in_training(x, x, gamma, beta, epsilon=epsilon, data_format=tf_data_format) -@keras_export('keras.backend.normalize_batch_in_training') @doc_controls.do_not_generate_docs def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): """Computes mean and std for batch then apply batch_normalization on batch. 
@@ -2993,7 +2913,6 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): x, gamma, beta, reduction_axes, epsilon=epsilon) -@keras_export('keras.backend.batch_normalization') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): @@ -3057,7 +2976,6 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): # SHAPE OPERATIONS -@keras_export('keras.backend.concatenate') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def concatenate(tensors, axis=-1): @@ -3096,7 +3014,6 @@ def concatenate(tensors, axis=-1): return array_ops.concat([to_dense(x) for x in tensors], axis) -@keras_export('keras.backend.reshape') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def reshape(x, shape): @@ -3127,7 +3044,6 @@ def reshape(x, shape): return array_ops.reshape(x, shape) -@keras_export('keras.backend.permute_dimensions') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def permute_dimensions(x, pattern): @@ -3160,7 +3076,6 @@ def permute_dimensions(x, pattern): return array_ops.transpose(x, perm=pattern) -@keras_export('keras.backend.resize_images') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def resize_images(x, height_factor, width_factor, data_format, @@ -3213,7 +3128,6 @@ def resize_images(x, height_factor, width_factor, data_format, return x -@keras_export('keras.backend.resize_volumes') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): @@ -3247,7 +3161,6 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): raise ValueError('Invalid data_format: ' + str(data_format)) -@keras_export('keras.backend.repeat_elements') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def repeat_elements(x, rep, axis): @@ -3310,7 +3223,6 @@ def repeat_elements(x, rep, axis): return x_rep -@keras_export('keras.backend.repeat') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def repeat(x, n): @@ -3347,7 +3259,6 @@ def repeat(x, n): return array_ops.tile(x, pattern) -@keras_export('keras.backend.arange') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def arange(start, stop=None, step=1, dtype='int32'): @@ -3387,7 +3298,6 @@ def arange(start, stop=None, step=1, dtype='int32'): return result -@keras_export('keras.backend.tile') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def tile(x, n): @@ -3406,7 +3316,6 @@ def tile(x, n): return array_ops.tile(x, n) -@keras_export('keras.backend.flatten') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def flatten(x): @@ -3433,7 +3342,6 @@ def flatten(x): return array_ops.reshape(x, [-1]) -@keras_export('keras.backend.batch_flatten') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def batch_flatten(x): @@ -3460,7 +3368,6 @@ def batch_flatten(x): return x -@keras_export('keras.backend.expand_dims') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def expand_dims(x, axis=-1): @@ -3476,7 +3383,6 @@ def expand_dims(x, axis=-1): return array_ops.expand_dims(x, axis) -@keras_export('keras.backend.squeeze') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def squeeze(x, axis): @@ -3492,7 +3398,6 @@ def squeeze(x, axis): return array_ops.squeeze(x, [axis]) -@keras_export('keras.backend.temporal_padding') @dispatch.add_dispatch_support 
@doc_controls.do_not_generate_docs def temporal_padding(x, padding=(1, 1)): @@ -3511,7 +3416,6 @@ def temporal_padding(x, padding=(1, 1)): return array_ops.pad(x, pattern) -@keras_export('keras.backend.spatial_2d_padding') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): @@ -3544,7 +3448,6 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): return array_ops.pad(x, pattern) -@keras_export('keras.backend.spatial_3d_padding') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): @@ -3590,7 +3493,6 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): return array_ops.pad(x, pattern) -@keras_export('keras.backend.stack') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def stack(x, axis=0): @@ -3618,7 +3520,6 @@ def stack(x, axis=0): return array_ops_stack.stack(x, axis=axis) -@keras_export('keras.backend.one_hot') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def one_hot(indices, num_classes): @@ -3639,7 +3540,6 @@ def one_hot(indices, num_classes): return array_ops.one_hot(indices, depth=num_classes, axis=-1) -@keras_export('keras.backend.reverse') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def reverse(x, axes): @@ -3687,7 +3587,6 @@ def reverse(x, axes): 3.0"""[3:] # Prune first newline and indent to match the docstring template. -@keras_export('keras.backend.get_value') @doc_controls.do_not_generate_docs def get_value(x): """Returns the value of a variable. @@ -3723,7 +3622,6 @@ def get_value(x): return x.eval(session=get_session((x,))) -@keras_export('keras.backend.batch_get_value') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def batch_get_value(tensors): @@ -3748,7 +3646,6 @@ def batch_get_value(tensors): return [] -@keras_export('keras.backend.set_value') @doc_controls.do_not_generate_docs def set_value(x, value): """Sets the value of a variable, from a Numpy array. @@ -3787,7 +3684,6 @@ def set_value(x, value): get_session().run(assign_op, feed_dict={assign_placeholder: value}) -@keras_export('keras.backend.batch_set_value') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def batch_set_value(tuples): @@ -3831,7 +3727,6 @@ def batch_set_value(tuples): set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING) -@keras_export('keras.backend.print_tensor') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def print_tensor(x, message='', summarize=3): @@ -4070,7 +3965,6 @@ def __call__(self, inputs): return nest.map_structure(self._eval_if_composite, output_structure) -@keras_export('keras.backend.function') @doc_controls.do_not_generate_docs def function(inputs, outputs, updates=None, name=None, **kwargs): """Instantiates a Keras function. @@ -4119,7 +4013,6 @@ def func(model_inputs): inputs, outputs, updates=updates, name=name, **kwargs) -@keras_export('keras.backend.gradients') @doc_controls.do_not_generate_docs def gradients(loss, variables): """Returns the gradients of `loss` w.r.t. `variables`. 
@@ -4135,7 +4028,6 @@ def gradients(loss, variables): loss, variables, colocate_gradients_with_ops=True) -@keras_export('keras.backend.stop_gradient') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def stop_gradient(variables): @@ -4158,7 +4050,6 @@ def stop_gradient(variables): # CONTROL FLOW -@keras_export('keras.backend.rnn') @dispatch.add_dispatch_support def rnn(step_function, inputs, @@ -4551,7 +4442,6 @@ def set_shape(output_): return last_output, outputs, new_states -@keras_export('keras.backend.switch') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def switch(condition, then_expression, else_expression): @@ -4617,7 +4507,6 @@ def else_expression_fn(): return x -@keras_export('keras.backend.in_train_phase') @doc_controls.do_not_generate_docs def in_train_phase(x, alt, training=None): """Selects `x` in train phase, and `alt` otherwise. @@ -4663,7 +4552,6 @@ def in_train_phase(x, alt, training=None): return x -@keras_export('keras.backend.in_test_phase') @doc_controls.do_not_generate_docs def in_test_phase(x, alt, training=None): """Selects `x` in test phase, and `alt` otherwise. @@ -4688,7 +4576,6 @@ def in_test_phase(x, alt, training=None): # NN OPERATIONS -@keras_export('keras.backend.relu') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def relu(x, alpha=0., max_value=None, threshold=0): @@ -4746,7 +4633,6 @@ def relu(x, alpha=0., max_value=None, threshold=0): return x -@keras_export('keras.backend.elu') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def elu(x, alpha=1.): @@ -4766,7 +4652,6 @@ def elu(x, alpha=1.): return array_ops.where_v2(x > 0, res, alpha * res) -@keras_export('keras.backend.softmax') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def softmax(x, axis=-1): @@ -4783,7 +4668,6 @@ def softmax(x, axis=-1): return nn.softmax(x, axis=axis) -@keras_export('keras.backend.softplus') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def softplus(x): @@ -4798,7 +4682,6 @@ def softplus(x): return math_ops.softplus(x) -@keras_export('keras.backend.softsign') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def softsign(x): @@ -4813,7 +4696,6 @@ def softsign(x): return nn.softsign(x) -@keras_export('keras.backend.categorical_crossentropy') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def categorical_crossentropy(target, output, from_logits=False, axis=-1): @@ -4896,7 +4778,6 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): return -math_ops.reduce_sum(target * math_ops.log(output), axis) -@keras_export('keras.backend.sparse_categorical_crossentropy') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): @@ -4990,7 +4871,6 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): return res -@keras_export('keras.backend.binary_crossentropy') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def binary_crossentropy(target, output, from_logits=False): @@ -5041,7 +4921,6 @@ def binary_crossentropy(target, output, from_logits=False): return -bce -@keras_export('keras.backend.sigmoid') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def sigmoid(x): @@ -5056,7 +4935,6 @@ def sigmoid(x): return nn.sigmoid(x) -@keras_export('keras.backend.hard_sigmoid') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def hard_sigmoid(x): @@ -5080,7 +4958,6 @@ 
def hard_sigmoid(x): return x -@keras_export('keras.backend.tanh') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def tanh(x): @@ -5095,7 +4972,6 @@ def tanh(x): return nn.tanh(x) -@keras_export('keras.backend.dropout') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def dropout(x, level, noise_shape=None, seed=None): @@ -5117,7 +4993,6 @@ def dropout(x, level, noise_shape=None, seed=None): return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=seed) -@keras_export('keras.backend.l2_normalize') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def l2_normalize(x, axis=None): @@ -5133,7 +5008,6 @@ def l2_normalize(x, axis=None): return nn.l2_normalize(x, axis=axis) -@keras_export('keras.backend.in_top_k') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def in_top_k(predictions, targets, k): @@ -5237,7 +5111,6 @@ def _preprocess_padding(padding): return padding -@keras_export('keras.backend.conv1d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def conv1d(x, @@ -5289,7 +5162,6 @@ def conv1d(x, return x -@keras_export('keras.backend.conv2d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def conv2d(x, @@ -5334,7 +5206,6 @@ def conv2d(x, return x -@keras_export('keras.backend.conv2d_transpose') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def conv2d_transpose(x, @@ -5476,7 +5347,6 @@ def separable_conv1d(x, return x -@keras_export('keras.backend.separable_conv2d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def separable_conv2d(x, @@ -5535,7 +5405,6 @@ def separable_conv2d(x, return x -@keras_export('keras.backend.depthwise_conv2d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def depthwise_conv2d(x, @@ -5586,7 +5455,6 @@ def depthwise_conv2d(x, return x -@keras_export('keras.backend.conv3d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def conv3d(x, @@ -5690,7 +5558,6 @@ def conv3d_transpose(x, return x -@keras_export('keras.backend.pool2d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def pool2d(x, @@ -5751,7 +5618,6 @@ def pool2d(x, return x -@keras_export('keras.backend.pool3d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def pool3d(x, @@ -5883,7 +5749,6 @@ def local_conv(inputs, return permute_dimensions(output, permutation) -@keras_export('keras.backend.local_conv1d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): @@ -5920,7 +5785,6 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): data_format) -@keras_export('keras.backend.local_conv2d') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def local_conv2d(inputs, @@ -5963,7 +5827,6 @@ def local_conv2d(inputs, data_format) -@keras_export('keras.backend.bias_add') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def bias_add(x, bias, data_format=None): @@ -6009,7 +5872,6 @@ def bias_add(x, bias, data_format=None): # RANDOMNESS -@keras_export('keras.backend.random_normal') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): @@ -6047,7 +5909,6 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed) -@keras_export('keras.backend.random_uniform') @dispatch.add_dispatch_support 
@doc_controls.do_not_generate_docs def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): @@ -6081,7 +5942,6 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed) -@keras_export('keras.backend.random_binomial') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def random_binomial(shape, p=0.0, dtype=None, seed=None): @@ -6116,7 +5976,6 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): return random_bernoulli(shape, p, dtype, seed) -@keras_export('keras.backend.random_bernoulli') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def random_bernoulli(shape, p=0.0, dtype=None, seed=None): @@ -6140,7 +5999,6 @@ def random_bernoulli(shape, p=0.0, dtype=None, seed=None): array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) -@keras_export('keras.backend.truncated_normal') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): @@ -6176,7 +6034,6 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): # in TensorFlow's CTC implementation -@keras_export('keras.backend.ctc_label_dense_to_sparse') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def ctc_label_dense_to_sparse(labels, label_lengths): @@ -6224,7 +6081,6 @@ def range_less_than(old_input, current_input): math_ops.cast(label_shape, dtypes_module.int64)) -@keras_export('keras.backend.ctc_batch_cost') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def ctc_batch_cost(y_true, y_pred, input_length, label_length): @@ -6258,7 +6114,6 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length): inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1) -@keras_export('keras.backend.ctc_decode') @dispatch.add_dispatch_support @doc_controls.do_not_generate_docs def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): @@ -6316,7 +6171,6 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): # HIGH ORDER FUNCTIONS -@keras_export('keras.backend.map_fn') @doc_controls.do_not_generate_docs def map_fn(fn, elems, name=None, dtype=None): """Map the function fn over the elements elems and return the outputs. @@ -6333,7 +6187,6 @@ def map_fn(fn, elems, name=None, dtype=None): return map_fn_lib.map_fn(fn, elems, name=name, dtype=dtype) -@keras_export('keras.backend.foldl') @doc_controls.do_not_generate_docs def foldl(fn, elems, initializer=None, name=None): """Reduce elems using fn to combine them from left to right. @@ -6351,7 +6204,6 @@ def foldl(fn, elems, initializer=None, name=None): return functional_ops.foldl(fn, elems, initializer=initializer, name=name) -@keras_export('keras.backend.foldr') @doc_controls.do_not_generate_docs def foldr(fn, elems, initializer=None, name=None): """Reduce elems using fn to combine them from right to left. diff --git a/tensorflow/python/keras/backend_config.py b/tensorflow/python/keras/backend_config.py index 63bdb2b96c1222..ad2adba81f23e2 100644 --- a/tensorflow/python/keras/backend_config.py +++ b/tensorflow/python/keras/backend_config.py @@ -15,7 +15,6 @@ """Keras backend config API.""" from tensorflow.python.util import dispatch -from tensorflow.python.util.tf_export import keras_export # The type of float to use throughout a session. 
_FLOATX = 'float32' @@ -27,7 +26,6 @@ _IMAGE_DATA_FORMAT = 'channels_last' -@keras_export('keras.backend.epsilon') @dispatch.add_dispatch_support def epsilon(): """Returns the value of the fuzz factor used in numeric expressions. @@ -42,7 +40,6 @@ def epsilon(): return _EPSILON -@keras_export('keras.backend.set_epsilon') def set_epsilon(value): """Sets the value of the fuzz factor used in numeric expressions. @@ -61,7 +58,6 @@ def set_epsilon(value): _EPSILON = value -@keras_export('keras.backend.floatx') def floatx(): """Returns the default float type, as a string. @@ -77,7 +73,6 @@ def floatx(): return _FLOATX -@keras_export('keras.backend.set_floatx') def set_floatx(value): """Sets the default float type. @@ -108,7 +103,6 @@ def set_floatx(value): _FLOATX = str(value) -@keras_export('keras.backend.image_data_format') @dispatch.add_dispatch_support def image_data_format(): """Returns the default image data format convention. @@ -123,7 +117,6 @@ def image_data_format(): return _IMAGE_DATA_FORMAT -@keras_export('keras.backend.set_image_data_format') def set_image_data_format(data_format): """Sets the value of the image data format convention. diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index c5cbd1873c3058..f243a952d987a5 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -62,7 +62,6 @@ from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.saved_model import save_options as save_options_lib from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls try: @@ -199,7 +198,6 @@ def make_logs(model, logs, outputs, mode, prefix=''): return logs -@keras_export('keras.callbacks.CallbackList') class CallbackList: """Container abstracting a list of callbacks.""" @@ -590,7 +588,6 @@ def _disallow_batch_hooks_in_ps_strategy(self): # pylint: enable=protected-access -@keras_export('keras.callbacks.Callback') class Callback: """Abstract base class used to build new callbacks. @@ -902,7 +899,6 @@ def _implements_predict_batch_hooks(self): not generic_utils.is_default(self.on_predict_batch_end)) -@keras_export('keras.callbacks.BaseLogger') class BaseLogger(Callback): """Callback that accumulates epoch averages of metrics. @@ -951,7 +947,6 @@ def on_epoch_end(self, epoch, logs=None): logs[k] = self.totals[k] / self.seen -@keras_export('keras.callbacks.TerminateOnNaN') class TerminateOnNaN(Callback): """Callback that terminates training when a NaN loss is encountered. """ @@ -970,7 +965,6 @@ def on_batch_end(self, batch, logs=None): self.model.stop_training = True -@keras_export('keras.callbacks.ProgbarLogger') class ProgbarLogger(Callback): """Callback that prints metrics to stdout. @@ -1137,7 +1131,6 @@ def _finalize_progbar(self, logs, counter): self.progbar.update(self.target, list(logs.items()), finalize=True) -@keras_export('keras.callbacks.History') class History(Callback): """Callback that records events into a `History` object. @@ -1177,7 +1170,6 @@ def on_epoch_end(self, epoch, logs=None): self.model.history = self -@keras_export('keras.callbacks.ModelCheckpoint') class ModelCheckpoint(Callback): """Callback to save the Keras model or model weights at some frequency. 
@@ -1593,7 +1585,6 @@ def _get_most_recently_modified_file_matching_pattern(self, pattern): return file_path_with_largest_file_name -@keras_export('keras.callbacks.experimental.BackupAndRestore', v1=[]) class BackupAndRestore(Callback): """Callback to back up and restore the training state. @@ -1708,7 +1699,6 @@ def on_epoch_end(self, epoch, logs=None): self._training_state.back_up(epoch) -@keras_export('keras.callbacks.EarlyStopping') class EarlyStopping(Callback): """Stop training when a monitored metric has stopped improving. @@ -1852,7 +1842,6 @@ def _is_improvement(self, monitor_value, reference_value): return self.monitor_op(monitor_value - self.min_delta, reference_value) -@keras_export('keras.callbacks.RemoteMonitor') class RemoteMonitor(Callback): """Callback used to stream events to a server. @@ -1915,7 +1904,6 @@ def on_epoch_end(self, epoch, logs=None): 'root server at ' + str(self.root)) -@keras_export('keras.callbacks.LearningRateScheduler') class LearningRateScheduler(Callback): """Learning rate scheduler. @@ -2026,7 +2014,6 @@ def keras_model_summary(name, data, step=None): tag=tag, tensor=tensor, step=step, metadata=summary_metadata) -@keras_export('keras.callbacks.TensorBoard', v1=[]) class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. @@ -2596,7 +2583,6 @@ def _stop_profiler(self, save=True): self._profiler_started = False -@keras_export('keras.callbacks.ReduceLROnPlateau') class ReduceLROnPlateau(Callback): """Reduce learning rate when a metric has stopped improving. @@ -2720,7 +2706,6 @@ def in_cooldown(self): return self.cooldown_counter > 0 -@keras_export('keras.callbacks.CSVLogger') class CSVLogger(Callback): """Callback that streams epoch results to a CSV file. @@ -2803,7 +2788,6 @@ def on_train_end(self, logs=None): self.writer = None -@keras_export('keras.callbacks.LambdaCallback') class LambdaCallback(Callback): r"""Callback for creating simple, custom callbacks on-the-fly. diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index dc9505dd5cde8c..58e5cc534b37b3 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -32,10 +32,8 @@ from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.summary import summary as tf_summary from tensorflow.python.training import saver -from tensorflow.python.util.tf_export import keras_export -@keras_export(v1=['keras.callbacks.TensorBoard']) class TensorBoard(callbacks.TensorBoard): # pylint: disable=line-too-long """Enable visualizations for TensorBoard. diff --git a/tensorflow/python/keras/constraints.py b/tensorflow/python/keras/constraints.py index a686b9b8c989ef..85a38e8dc85066 100644 --- a/tensorflow/python/keras/constraints.py +++ b/tensorflow/python/keras/constraints.py @@ -24,11 +24,9 @@ from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import math_ops from tensorflow.python.ops import while_loop -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls -@keras_export('keras.constraints.Constraint') class Constraint: """Base class for weight constraints. @@ -80,7 +78,6 @@ def get_config(self): return {} -@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm') class MaxNorm(Constraint): """MaxNorm weight constraint. 
@@ -121,7 +118,6 @@ def get_config(self): return {'max_value': self.max_value, 'axis': self.axis} -@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg') class NonNeg(Constraint): """Constrains the weights to be non-negative. @@ -132,7 +128,6 @@ def __call__(self, w): return w * math_ops.cast(math_ops.greater_equal(w, 0.), backend.floatx()) -@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm') class UnitNorm(Constraint): """Constrains the weights incident to each hidden unit to have unit norm. @@ -167,7 +162,6 @@ def get_config(self): return {'axis': self.axis} -@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm') class MinMaxNorm(Constraint): """MinMaxNorm weight constraint. @@ -224,8 +218,6 @@ def get_config(self): } -@keras_export('keras.constraints.RadialConstraint', - 'keras.constraints.radial_constraint') class RadialConstraint(Constraint): """Constrains `Conv2D` kernel weights to be the same for each radius. @@ -322,12 +314,10 @@ def body_fn(i, array): unitnorm = unit_norm -@keras_export('keras.constraints.serialize') def serialize(constraint): return serialize_keras_object(constraint) -@keras_export('keras.constraints.deserialize') def deserialize(config, custom_objects=None): return deserialize_keras_object( config, @@ -336,7 +326,6 @@ def deserialize(config, custom_objects=None): printable_module_name='constraint') -@keras_export('keras.constraints.get') def get(identifier): if identifier is None: return None diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 6b97d9ce9904ed..d7f5cbfdf92e0a 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -75,7 +75,6 @@ from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util.tf_export import get_canonical_name_for_symbol -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls # A module that only depends on `keras.layers` import these from here. @@ -95,7 +94,6 @@ ragged_tensor.RaggedTensor) -@keras_export('keras.layers.Layer') class Layer(module.Module, version_utils.LayerVersionSelector): """This is the class from which all layers inherit. diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index f18cce3609b5a4..0fd0b92d37711c 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -36,7 +36,6 @@ from tensorflow.python.trackable import base as tracking from tensorflow.python.training.saving import saveable_object_util from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export _call_context = threading.local() @@ -723,7 +722,6 @@ def _mark_as_return(tensor): V2_DTYPE_BEHAVIOR = None -@keras_export(v1=['keras.layers.enable_v2_dtype_behavior']) def enable_v2_dtype_behavior(): """Enable the V2 dtype behavior for Keras layers. @@ -758,7 +756,6 @@ def enable_v2_dtype_behavior(): V2_DTYPE_BEHAVIOR = True -@keras_export(v1=['keras.layers.disable_v2_dtype_behavior']) def disable_v2_dtype_behavior(): """Disables the V2 dtype behavior for Keras layers. 
diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index 0d29c21c83dcbb..bde3f322b48468 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -35,10 +35,8 @@ from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.trackable import base as trackable -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer') class PreprocessingLayer(Layer, metaclass=abc.ABCMeta): """Base class for Preprocessing Layers. diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 50a58757df34bc..f48da337048157 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -49,7 +49,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.types import data as data_types from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export class DataAdapter(object, metaclass=abc.ABCMeta): @@ -1528,7 +1527,6 @@ def _split(t, start, end): return train_arrays, val_arrays -@keras_export("keras.utils.unpack_x_y_sample_weight", v1=[]) def unpack_x_y_sample_weight(data): """Unpacks user-provided data tuple. @@ -1590,7 +1588,6 @@ def train_step(self, data): raise ValueError(error_msg) -@keras_export("keras.utils.pack_x_y_sample_weight", v1=[]) def pack_x_y_sample_weight(x, y=None, sample_weight=None): """Packs user-provided data into a tuple. diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index b659860cbe4435..6af2f52896d15b 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -26,7 +26,6 @@ from tensorflow.python.keras.engine import node as node_module from tensorflow.python.keras.saving.saved_model import layer_serialization from tensorflow.python.keras.utils import tf_utils -from tensorflow.python.util.tf_export import keras_export def _assert_other_arg_none(arg_name, arg): @@ -36,7 +35,6 @@ def _assert_other_arg_none(arg_name, arg): 'but %s is not None.' % arg_name) -@keras_export('keras.layers.InputLayer') class InputLayer(base_layer.Layer): """Layer to be used as an entry point into a Network (a graph of layers). @@ -254,7 +252,6 @@ def _trackable_saved_model_saver(self): return layer_serialization.InputLayerSavedModelSaver(self) -@keras_export('keras.Input', 'keras.layers.Input') def Input( # pylint: disable=invalid-name shape=None, batch_size=None, diff --git a/tensorflow/python/keras/engine/input_spec.py b/tensorflow/python/keras/engine/input_spec.py index b1b791fc003197..27dac2998a1332 100644 --- a/tensorflow/python/keras/engine/input_spec.py +++ b/tensorflow/python/keras/engine/input_spec.py @@ -21,11 +21,9 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import tf_export -@keras_export('keras.layers.InputSpec') @tf_export(v1=['layers.InputSpec']) class InputSpec(object): """Specifies the rank, dtype and shape of every input to a layer. 
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index d03e2a68c0e299..0f46e17d37837d 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -36,7 +36,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.trackable import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export SINGLE_LAYER_OUTPUT_ERROR_MSG = ('All layers in a Sequential model should have ' @@ -44,7 +43,6 @@ 'layers, use the functional API.') -@keras_export('keras.Sequential', 'keras.models.Sequential') class Sequential(functional.Functional): """`Sequential` groups a linear stack of layers into a `tf.keras.Model`. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 56fcbaaeb4e4bc..1e94ca45aef0d1 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -78,7 +78,6 @@ from tensorflow.python.types import data as data_types from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -129,7 +128,6 @@ def is_functional_model_init_params(args, kwargs): 'inputs' in kwargs and 'outputs' in kwargs) -@keras_export('keras.Model', 'keras.models.Model') class Model(base_layer.Layer, version_utils.ModelVersionSelector): """`Model` groups layers into an object with training and inference features. diff --git a/tensorflow/python/keras/initializers/__init__.py b/tensorflow/python/keras/initializers/__init__.py index 222bb2d4a863eb..deba75c8b7e1a2 100644 --- a/tensorflow/python/keras/initializers/__init__.py +++ b/tensorflow/python/keras/initializers/__init__.py @@ -22,7 +22,6 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_inspect as inspect from tensorflow.python.ops import init_ops -from tensorflow.python.util.tf_export import keras_export # LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it @@ -125,12 +124,10 @@ def populate_deserializable_objects(): # Utility functions -@keras_export('keras.initializers.serialize') def serialize(initializer): return generic_utils.serialize_keras_object(initializer) -@keras_export('keras.initializers.deserialize') def deserialize(config, custom_objects=None): """Return an `Initializer` object from its config.""" populate_deserializable_objects() @@ -141,7 +138,6 @@ def deserialize(config, custom_objects=None): printable_module_name='initializer') -@keras_export('keras.initializers.get') def get(identifier): """Retrieve a Keras initializer by the identifier. 
diff --git a/tensorflow/python/keras/initializers/initializers_v1.py b/tensorflow/python/keras/initializers/initializers_v1.py index 3a4de90bc9289c..38cc0cee4ce15e 100644 --- a/tensorflow/python/keras/initializers/initializers_v1.py +++ b/tensorflow/python/keras/initializers/initializers_v1.py @@ -16,7 +16,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import init_ops -from tensorflow.python.util.tf_export import keras_export _v1_zeros_initializer = init_ops.Zeros @@ -28,27 +27,7 @@ _v1_glorot_uniform_initializer = init_ops.GlorotUniform _v1_glorot_normal_initializer = init_ops.GlorotNormal -keras_export(v1=['keras.initializers.Zeros', 'keras.initializers.zeros'])( - _v1_zeros_initializer) -keras_export(v1=['keras.initializers.Ones', 'keras.initializers.ones'])( - _v1_ones_initializer) -keras_export(v1=['keras.initializers.Constant', 'keras.initializers.constant'])( - _v1_constant_initializer) -keras_export(v1=['keras.initializers.VarianceScaling'])( - _v1_variance_scaling_initializer) -keras_export(v1=['keras.initializers.Orthogonal', - 'keras.initializers.orthogonal'])(_v1_orthogonal_initializer) -keras_export(v1=['keras.initializers.Identity', - 'keras.initializers.identity'])(_v1_identity) -keras_export(v1=['keras.initializers.glorot_uniform'])( - _v1_glorot_uniform_initializer) -keras_export(v1=['keras.initializers.glorot_normal'])( - _v1_glorot_normal_initializer) - - -@keras_export(v1=['keras.initializers.RandomNormal', - 'keras.initializers.random_normal', - 'keras.initializers.normal']) + class RandomNormal(init_ops.RandomNormal): def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): @@ -56,9 +35,6 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): mean=mean, stddev=stddev, seed=seed, dtype=dtype) -@keras_export(v1=['keras.initializers.RandomUniform', - 'keras.initializers.random_uniform', - 'keras.initializers.uniform']) class RandomUniform(init_ops.RandomUniform): def __init__(self, minval=-0.05, maxval=0.05, seed=None, @@ -67,8 +43,6 @@ def __init__(self, minval=-0.05, maxval=0.05, seed=None, minval=minval, maxval=maxval, seed=seed, dtype=dtype) -@keras_export(v1=['keras.initializers.TruncatedNormal', - 'keras.initializers.truncated_normal']) class TruncatedNormal(init_ops.TruncatedNormal): def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): @@ -76,7 +50,6 @@ def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32): mean=mean, stddev=stddev, seed=seed, dtype=dtype) -@keras_export(v1=['keras.initializers.lecun_normal']) class LecunNormal(init_ops.VarianceScaling): def __init__(self, seed=None): @@ -87,7 +60,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export(v1=['keras.initializers.lecun_uniform']) class LecunUniform(init_ops.VarianceScaling): def __init__(self, seed=None): @@ -98,7 +70,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export(v1=['keras.initializers.he_normal']) class HeNormal(init_ops.VarianceScaling): def __init__(self, seed=None): @@ -109,7 +80,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export(v1=['keras.initializers.he_uniform']) class HeUniform(init_ops.VarianceScaling): def __init__(self, seed=None): diff --git a/tensorflow/python/keras/initializers/initializers_v2.py b/tensorflow/python/keras/initializers/initializers_v2.py index 67b304a5f83e9b..ba0a932aaf5b88 100644 --- a/tensorflow/python/keras/initializers/initializers_v2.py +++ 
b/tensorflow/python/keras/initializers/initializers_v2.py @@ -26,13 +26,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import stateless_random_ops -from tensorflow.python.util.tf_export import keras_export _PARTITION_SHAPE = 'partition_shape' _PARTITION_OFFSET = 'partition_offset' -@keras_export('keras.initializers.Initializer') class Initializer(object): """Initializer base class: all Keras initializers inherit from this class. @@ -114,7 +112,6 @@ def from_config(cls, config): return cls(**config) -@keras_export('keras.initializers.Zeros', 'keras.initializers.zeros', v1=[]) class Zeros(Initializer): """Initializer that generates tensors initialized to 0. @@ -151,7 +148,6 @@ def __call__(self, shape, dtype=None, **kwargs): return array_ops.zeros(shape, dtype) -@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[]) class Ones(Initializer): """Initializer that generates tensors initialized to 1. @@ -188,9 +184,6 @@ def __call__(self, shape, dtype=None, **kwargs): return array_ops.ones(shape, dtype) -@keras_export('keras.initializers.Constant', - 'keras.initializers.constant', - v1=[]) class Constant(Initializer): """Initializer that generates tensors with constant values. @@ -236,9 +229,6 @@ def get_config(self): return {'value': self.value} -@keras_export('keras.initializers.RandomUniform', - 'keras.initializers.random_uniform', - v1=[]) class RandomUniform(Initializer): """Initializer that generates tensors with a uniform distribution. @@ -299,9 +289,6 @@ def get_config(self): } -@keras_export('keras.initializers.RandomNormal', - 'keras.initializers.random_normal', - v1=[]) class RandomNormal(Initializer): """Initializer that generates tensors with a normal distribution. @@ -359,9 +346,6 @@ def get_config(self): } -@keras_export('keras.initializers.TruncatedNormal', - 'keras.initializers.truncated_normal', - v1=[]) class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. @@ -424,9 +408,6 @@ def get_config(self): } -@keras_export('keras.initializers.VarianceScaling', - 'keras.initializers.variance_scaling', - v1=[]) class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. @@ -531,9 +512,6 @@ def get_config(self): } -@keras_export('keras.initializers.Orthogonal', - 'keras.initializers.orthogonal', - v1=[]) class Orthogonal(Initializer): """Initializer that generates an orthogonal matrix. @@ -615,9 +593,6 @@ def get_config(self): return {'gain': self.gain, 'seed': self.seed} -@keras_export('keras.initializers.Identity', - 'keras.initializers.identity', - v1=[]) class Identity(Initializer): """Initializer that generates the identity matrix. @@ -665,9 +640,6 @@ def get_config(self): return {'gain': self.gain} -@keras_export('keras.initializers.GlorotUniform', - 'keras.initializers.glorot_uniform', - v1=[]) class GlorotUniform(VarianceScaling): """The Glorot uniform initializer, also called Xavier uniform initializer. @@ -708,9 +680,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export('keras.initializers.GlorotNormal', - 'keras.initializers.glorot_normal', - v1=[]) class GlorotNormal(VarianceScaling): """The Glorot normal initializer, also called Xavier normal initializer. 
@@ -752,9 +721,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export('keras.initializers.LecunNormal', - 'keras.initializers.lecun_normal', - v1=[]) class LecunNormal(VarianceScaling): """Lecun normal initializer. @@ -800,9 +766,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export('keras.initializers.LecunUniform', - 'keras.initializers.lecun_uniform', - v1=[]) class LecunUniform(VarianceScaling): """Lecun uniform initializer. @@ -843,9 +806,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export('keras.initializers.HeNormal', - 'keras.initializers.he_normal', - v1=[]) class HeNormal(VarianceScaling): """He normal initializer. @@ -883,9 +843,6 @@ def get_config(self): return {'seed': self.seed} -@keras_export('keras.initializers.HeUniform', - 'keras.initializers.he_uniform', - v1=[]) class HeUniform(VarianceScaling): """He uniform variance scaling initializer. diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index eefbb502d12d51..4adc3bed950888 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -24,14 +24,12 @@ from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import math_ops -from tensorflow.python.util.tf_export import keras_export def get_globals(): return globals() -@keras_export('keras.layers.LeakyReLU') class LeakyReLU(Layer): """Leaky version of a Rectified Linear Unit. @@ -88,7 +86,6 @@ def compute_output_shape(self, input_shape): return input_shape -@keras_export('keras.layers.PReLU') class PReLU(Layer): """Parametric Rectified Linear Unit. @@ -182,7 +179,6 @@ def compute_output_shape(self, input_shape): return input_shape -@keras_export('keras.layers.ELU') class ELU(Layer): """Exponential Linear Unit. @@ -226,7 +222,6 @@ def compute_output_shape(self, input_shape): return input_shape -@keras_export('keras.layers.ThresholdedReLU') class ThresholdedReLU(Layer): """Thresholded Rectified Linear Unit. @@ -291,7 +286,6 @@ def _large_compatible_negative(tensor_type): return -1e9 -@keras_export('keras.layers.Softmax') class Softmax(Layer): """Softmax activation function. @@ -359,7 +353,6 @@ def compute_output_shape(self, input_shape): return input_shape -@keras_export('keras.layers.ReLU') class ReLU(Layer): """Rectified Linear Unit activation function. diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index fa76e2772c485e..b2eb7089b6d095 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -40,7 +40,6 @@ from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops -from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-classes-have-attributes @@ -383,7 +382,6 @@ def _get_padding_op(self): return op_padding -@keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D') class Conv1D(Conv): """1D convolution layer (e.g. temporal convolution). @@ -524,7 +522,6 @@ def __init__(self, **kwargs) -@keras_export('keras.layers.Conv2D', 'keras.layers.Convolution2D') class Conv2D(Conv): """2D convolution layer (e.g. spatial convolution over images). 
@@ -684,7 +681,6 @@ def __init__(self, **kwargs) -@keras_export('keras.layers.Conv3D', 'keras.layers.Convolution3D') class Conv3D(Conv): """3D convolution layer (e.g. spatial convolution over volumes). @@ -831,8 +827,6 @@ def __init__(self, **kwargs) -@keras_export('keras.layers.Conv1DTranspose', - 'keras.layers.Convolution1DTranspose') class Conv1DTranspose(Conv1D): """Transposed convolution layer (sometimes called Deconvolution). @@ -1078,8 +1072,6 @@ def get_config(self): return config -@keras_export('keras.layers.Conv2DTranspose', - 'keras.layers.Convolution2DTranspose') class Conv2DTranspose(Conv2D): """Transposed convolution layer (sometimes called Deconvolution). @@ -1381,8 +1373,6 @@ def get_config(self): return config -@keras_export('keras.layers.Conv3DTranspose', - 'keras.layers.Convolution3DTranspose') class Conv3DTranspose(Conv3D): """Transposed convolution layer (sometimes called Deconvolution). @@ -1905,8 +1895,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.SeparableConv1D', - 'keras.layers.SeparableConvolution1D') class SeparableConv1D(SeparableConv): """Depthwise separable 1D convolution. @@ -2086,8 +2074,6 @@ def call(self, inputs): return outputs -@keras_export('keras.layers.SeparableConv2D', - 'keras.layers.SeparableConvolution2D') class SeparableConv2D(SeparableConv): """Depthwise separable 2D convolution. @@ -2260,7 +2246,6 @@ def call(self, inputs): return outputs -@keras_export('keras.layers.DepthwiseConv2D') class DepthwiseConv2D(Conv2D): """Depthwise 2D convolution. @@ -2491,7 +2476,6 @@ def get_config(self): return config -@keras_export('keras.layers.UpSampling1D') class UpSampling1D(Layer): """Upsampling layer for 1D inputs. @@ -2548,7 +2532,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.UpSampling2D') class UpSampling2D(Layer): """Upsampling layer for 2D inputs. @@ -2652,7 +2635,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.UpSampling3D') class UpSampling3D(Layer): """Upsampling layer for 3D inputs. @@ -2733,7 +2715,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.ZeroPadding1D') class ZeroPadding1D(Layer): """Zero-padding layer for 1D input (e.g. temporal sequence). @@ -2799,7 +2780,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.ZeroPadding2D') class ZeroPadding2D(Layer): """Zero-padding layer for 2D input (e.g. picture). @@ -2924,7 +2904,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.ZeroPadding3D') class ZeroPadding3D(Layer): """Zero-padding layer for 3D data (spatial or spatio-temporal). @@ -3050,7 +3029,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Cropping1D') class Cropping1D(Layer): """Cropping layer for 1D input (e.g. temporal sequence). @@ -3111,7 +3089,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Cropping2D') class Cropping2D(Layer): """Cropping layer for 2D input (e.g. picture). @@ -3238,7 +3215,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Cropping3D') class Cropping3D(Layer): """Cropping layer for 3D data (e.g. 
spatial or spatio-temporal). diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index 0e56fc4fc1a089..15a38060a644aa 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -31,7 +31,6 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops -from tensorflow.python.util.tf_export import keras_export class ConvRNN2D(RNN): @@ -693,7 +692,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.ConvLSTM2D') class ConvLSTM2D(ConvRNN2D): """2D Convolutional LSTM layer. diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 1132b78dfe422d..c9be0a3cc5ba10 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -62,11 +62,9 @@ from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import get_canonical_name_for_symbol from tensorflow.python.util.tf_export import get_symbol_from_name -from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-classes-have-attributes -@keras_export('keras.layers.Masking') class Masking(Layer): """Masks a sequence by using a mask value to skip timesteps. @@ -133,7 +131,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Dropout') class Dropout(Layer): """Applies Dropout to the input. @@ -236,7 +233,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.SpatialDropout1D') class SpatialDropout1D(Dropout): """Spatial 1D version of Dropout. @@ -278,7 +274,6 @@ def _get_noise_shape(self, inputs): return noise_shape -@keras_export('keras.layers.SpatialDropout2D') class SpatialDropout2D(Dropout): """Spatial 2D version of Dropout. @@ -337,7 +332,6 @@ def _get_noise_shape(self, inputs): return (input_shape[0], 1, 1, input_shape[3]) -@keras_export('keras.layers.SpatialDropout3D') class SpatialDropout3D(Dropout): """Spatial 3D version of Dropout. @@ -395,7 +389,6 @@ def _get_noise_shape(self, inputs): return (input_shape[0], 1, 1, 1, input_shape[4]) -@keras_export('keras.layers.Activation') class Activation(Layer): """Applies an activation function to an output. @@ -440,7 +433,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Reshape') class Reshape(Layer): """Layer that reshapes inputs into the given shape. @@ -555,7 +547,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Permute') class Permute(Layer): """Permutes the dimensions of the input according to a given pattern. @@ -613,7 +604,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Flatten') class Flatten(Layer): """Flattens the input. Does not affect the batch size. @@ -701,7 +691,6 @@ def get_config(self): return config -@keras_export('keras.layers.RepeatVector') class RepeatVector(Layer): """Repeats the input n times. @@ -747,7 +736,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Lambda') class Lambda(Layer): """Wraps arbitrary expressions as a `Layer` object. 
@@ -1073,7 +1061,6 @@ def _parse_function_from_config( return function -@keras_export('keras.layers.Dense') class Dense(Layer): """Just your regular densely-connected NN layer. @@ -1280,7 +1267,6 @@ def get_config(self): return config -@keras_export('keras.layers.ActivityRegularization') class ActivityRegularization(Layer): """Layer that applies an update to the cost function based input activity. diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py index 8a570f7245bc6b..08dd07c89fc9f7 100644 --- a/tensorflow/python/keras/layers/dense_attention.py +++ b/tensorflow/python/keras/layers/dense_attention.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.util.tf_export import keras_export class BaseDenseAttention(Layer): @@ -217,7 +216,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Attention') class Attention(BaseDenseAttention): """Dot-product attention layer, a.k.a. Luong-style attention. @@ -355,7 +353,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.AdditiveAttention') class AdditiveAttention(BaseDenseAttention): """Additive attention layer, a.k.a. Bahdanau-style attention. diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 381f3b9eac73a3..7a193cfd3de5bd 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -24,10 +24,8 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.layers.Embedding') class Embedding(Layer): """Turns positive integers (indexes) into dense vectors of fixed size. diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py index b7bcf9483180a3..f9a00f532bbb36 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -52,7 +52,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.trackable import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import tf_export _BIAS_VARIABLE_NAME = "bias" @@ -182,7 +181,6 @@ def get_state_shape(s): return nest.map_structure(get_state_shape, state_size) -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.RNNCell"]) @tf_export(v1=["nn.rnn_cell.RNNCell"]) class RNNCell(base_layer.Layer): """Abstract object representing an RNN cell. @@ -398,7 +396,6 @@ def __call__(self, inputs, state, scope=None, *args, **kwargs): self, inputs, state, scope=scope, *args, **kwargs) -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicRNNCell"]) @tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) class BasicRNNCell(LayerRNNCell): """The most basic RNN cell. @@ -495,7 +492,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.GRUCell"]) @tf_export(v1=["nn.rnn_cell.GRUCell"]) class GRUCell(LayerRNNCell): """Gated Recurrent Unit cell. 
@@ -635,7 +631,6 @@ def get_config(self): _LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMStateTuple"]) @tf_export(v1=["nn.rnn_cell.LSTMStateTuple"]) class LSTMStateTuple(_LSTMStateTuple): """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. @@ -656,7 +651,6 @@ def dtype(self): return c.dtype -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.BasicLSTMCell"]) @tf_export(v1=["nn.rnn_cell.BasicLSTMCell"]) class BasicLSTMCell(LayerRNNCell): """DEPRECATED: Please use `tf.compat.v1.nn.rnn_cell.LSTMCell` instead. @@ -827,7 +821,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.LSTMCell"]) @tf_export(v1=["nn.rnn_cell.LSTMCell"]) class LSTMCell(LayerRNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -1195,7 +1188,6 @@ def from_config(cls, config, custom_objects=None): "instance.") -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DropoutWrapper"]) @tf_export(v1=["nn.rnn_cell.DropoutWrapper"]) class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, _RNNCellWrapperV1): @@ -1207,7 +1199,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__ -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.ResidualWrapper"]) @tf_export(v1=["nn.rnn_cell.ResidualWrapper"]) class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, _RNNCellWrapperV1): @@ -1219,7 +1210,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__ -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.DeviceWrapper"]) @tf_export(v1=["nn.rnn_cell.DeviceWrapper"]) class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, _RNNCellWrapperV1): @@ -1230,7 +1220,6 @@ def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__ -@keras_export(v1=["keras.__internal__.legacy.rnn_cell.MultiRNNCell"]) @tf_export(v1=["nn.rnn_cell.MultiRNNCell"]) class MultiRNNCell(RNNCell): """RNN cell composed sequentially of multiple simple cells. diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py index 28612270cb744a..68461de0841f2f 100644 --- a/tensorflow/python/keras/layers/merge.py +++ b/tensorflow/python/keras/layers/merge.py @@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops_stack from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.util.tf_export import keras_export class _Merge(Layer): @@ -215,7 +214,6 @@ def compute_mask(self, inputs, mask=None): backend.concatenate(masks, axis=0), axis=0, keepdims=False) -@keras_export('keras.layers.Add') class Add(_Merge): """Layer that adds a list of inputs. @@ -252,7 +250,6 @@ def _merge_function(self, inputs): return output -@keras_export('keras.layers.Subtract') class Subtract(_Merge): """Layer that subtracts two inputs. @@ -291,7 +288,6 @@ def _merge_function(self, inputs): return inputs[0] - inputs[1] -@keras_export('keras.layers.Multiply') class Multiply(_Merge): """Layer that multiplies (element-wise) a list of inputs. 
@@ -321,7 +317,6 @@ def _merge_function(self, inputs): return output -@keras_export('keras.layers.Average') class Average(_Merge): """Layer that averages a list of inputs element-wise. @@ -358,7 +353,6 @@ def _merge_function(self, inputs): return output / len(inputs) -@keras_export('keras.layers.Maximum') class Maximum(_Merge): """Layer that computes the maximum (element-wise) a list of inputs. @@ -388,7 +382,6 @@ def _merge_function(self, inputs): return output -@keras_export('keras.layers.Minimum') class Minimum(_Merge): """Layer that computes the minimum (element-wise) a list of inputs. @@ -418,7 +411,6 @@ def _merge_function(self, inputs): return output -@keras_export('keras.layers.Concatenate') class Concatenate(_Merge): """Layer that concatenates a list of inputs. @@ -573,7 +565,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.Dot') class Dot(_Merge): """Layer that computes a dot product between samples in two tensors. @@ -735,7 +726,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.add') def add(inputs, **kwargs): """Functional interface to the `tf.keras.layers.Add` layer. @@ -769,7 +759,6 @@ def add(inputs, **kwargs): return Add(**kwargs)(inputs) -@keras_export('keras.layers.subtract') def subtract(inputs, **kwargs): """Functional interface to the `Subtract` layer. @@ -798,7 +787,6 @@ def subtract(inputs, **kwargs): return Subtract(**kwargs)(inputs) -@keras_export('keras.layers.multiply') def multiply(inputs, **kwargs): """Functional interface to the `Multiply` layer. @@ -829,7 +817,6 @@ def multiply(inputs, **kwargs): return Multiply(**kwargs)(inputs) -@keras_export('keras.layers.average') def average(inputs, **kwargs): """Functional interface to the `tf.keras.layers.Average` layer. @@ -865,7 +852,6 @@ def average(inputs, **kwargs): return Average(**kwargs)(inputs) -@keras_export('keras.layers.maximum') def maximum(inputs, **kwargs): """Functional interface to compute maximum (element-wise) list of `inputs`. @@ -897,7 +883,6 @@ def maximum(inputs, **kwargs): return Maximum(**kwargs)(inputs) -@keras_export('keras.layers.minimum') def minimum(inputs, **kwargs): """Functional interface to the `Minimum` layer. @@ -911,7 +896,6 @@ def minimum(inputs, **kwargs): return Minimum(**kwargs)(inputs) -@keras_export('keras.layers.concatenate') def concatenate(inputs, axis=-1, **kwargs): """Functional interface to the `Concatenate` layer. @@ -946,7 +930,6 @@ def concatenate(inputs, axis=-1, **kwargs): return Concatenate(axis=axis, **kwargs)(inputs) -@keras_export('keras.layers.dot') def dot(inputs, axes, normalize=False, **kwargs): """Functional interface to the `Dot` layer. diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py index a389e83b290ebd..61d24ee0acd2fa 100644 --- a/tensorflow/python/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -24,7 +24,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn -from tensorflow.python.util.tf_export import keras_export class Pooling1D(Layer): @@ -104,7 +103,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.MaxPool1D', 'keras.layers.MaxPooling1D') class MaxPooling1D(Pooling1D): """Max pooling operation for 1D temporal data. 
@@ -196,7 +194,6 @@ def __init__(self, pool_size=2, strides=None, **kwargs) -@keras_export('keras.layers.AveragePooling1D', 'keras.layers.AvgPool1D') class AveragePooling1D(Pooling1D): """Average pooling for temporal data. @@ -393,7 +390,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.MaxPool2D', 'keras.layers.MaxPooling2D') class MaxPooling2D(Pooling2D): """Max pooling operation for 2D spatial data. @@ -530,7 +526,6 @@ def __init__(self, padding=padding, data_format=data_format, **kwargs) -@keras_export('keras.layers.AveragePooling2D', 'keras.layers.AvgPool2D') class AveragePooling2D(Pooling2D): """Average pooling operation for spatial data. @@ -740,7 +735,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.MaxPool3D', 'keras.layers.MaxPooling3D') class MaxPooling3D(Pooling3D): """Max pooling operation for 3D data (spatial or spatio-temporal). @@ -811,7 +805,6 @@ def __init__(self, padding=padding, data_format=data_format, **kwargs) -@keras_export('keras.layers.AveragePooling3D', 'keras.layers.AvgPool3D') class AveragePooling3D(Pooling3D): """Average pooling operation for 3D data (spatial or spatio-temporal). @@ -913,8 +906,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.GlobalAveragePooling1D', - 'keras.layers.GlobalAvgPool1D') class GlobalAveragePooling1D(GlobalPooling1D): """Global average pooling operation for temporal data. @@ -987,7 +978,6 @@ def compute_mask(self, inputs, mask=None): return None -@keras_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D') class GlobalMaxPooling1D(GlobalPooling1D): """Global max pooling operation for 1D temporal data. @@ -1080,8 +1070,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.GlobalAveragePooling2D', - 'keras.layers.GlobalAvgPool2D') class GlobalAveragePooling2D(GlobalPooling2D): """Global average pooling operation for spatial data. @@ -1134,7 +1122,6 @@ def call(self, inputs): return backend.mean(inputs, axis=[2, 3], keepdims=self.keepdims) -@keras_export('keras.layers.GlobalMaxPool2D', 'keras.layers.GlobalMaxPooling2D') class GlobalMaxPooling2D(GlobalPooling2D): """Global max pooling operation for spatial data. @@ -1220,8 +1207,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.GlobalAveragePooling3D', - 'keras.layers.GlobalAvgPool3D') class GlobalAveragePooling3D(GlobalPooling3D): """Global Average pooling operation for 3D data. @@ -1268,7 +1253,6 @@ def call(self, inputs): return backend.mean(inputs, axis=[2, 3, 4], keepdims=self.keepdims) -@keras_export('keras.layers.GlobalMaxPool3D', 'keras.layers.GlobalMaxPooling3D') class GlobalMaxPooling3D(GlobalPooling3D): """Global Max pooling operation for 3D data. 
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index d3c5fec5048e67..94b85f29e693d9 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -44,7 +44,6 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.trackable import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -53,7 +52,6 @@ 'Using `implementation=1`.') -@keras_export('keras.layers.StackedRNNCells') class StackedRNNCells(Layer): """Wrapper allowing a stack of RNN cells to behave as a single cell. @@ -195,7 +193,6 @@ def from_config(cls, config, custom_objects=None): return cls(cells, **config) -@keras_export('keras.layers.RNN') class RNN(Layer): """Base class for recurrent layers. @@ -1003,7 +1000,6 @@ def _trackable_saved_model_saver(self): return layer_serialization.RNNSavedModelSaver(self) -@keras_export('keras.layers.AbstractRNNCell') class AbstractRNNCell(Layer): """Abstract object representing an RNN cell. @@ -1230,7 +1226,6 @@ def __setstate__(self, state): super(DropoutRNNCellMixin, self).__setstate__(state) -@keras_export('keras.layers.SimpleRNNCell') class SimpleRNNCell(DropoutRNNCellMixin, Layer): """Cell class for SimpleRNN. @@ -1435,7 +1430,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.layers.SimpleRNN') class SimpleRNN(RNN): """Fully-connected RNN where the output is to be fed back to input. @@ -1689,7 +1683,6 @@ def from_config(cls, config): return cls(**config) -@keras_export(v1=['keras.layers.GRUCell']) class GRUCell(DropoutRNNCellMixin, Layer): """Cell class for the GRU layer. @@ -1970,7 +1963,6 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None): return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) -@keras_export(v1=['keras.layers.GRU']) class GRU(RNN): """Gated Recurrent Unit - Cho et al. 2014. @@ -2249,7 +2241,6 @@ def from_config(cls, config): return cls(**config) -@keras_export(v1=['keras.layers.LSTMCell']) class LSTMCell(DropoutRNNCellMixin, Layer): """Cell class for the LSTM layer. @@ -2527,7 +2518,6 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None): self, inputs, batch_size, dtype)) -@keras_export('keras.experimental.PeepholeLSTMCell') class PeepholeLSTMCell(LSTMCell): """Equivalent to LSTMCell class but adds peephole connections. @@ -2646,7 +2636,6 @@ def _compute_carry_and_output_fused(self, z, c_tm1): return c, o -@keras_export(v1=['keras.layers.LSTM']) class LSTM(RNN): """Long Short-Term Memory layer - Hochreiter 1997. 
diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index 14e26c421680e4..65238330f93b29 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -35,7 +35,6 @@ from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_inspect as inspect -from tensorflow.python.util.tf_export import keras_export ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, convolutional_recurrent, core, dense_attention, @@ -95,12 +94,10 @@ def populate_deserializable_objects(): LOCAL.ALL_OBJECTS['dot'] = merge.dot -@keras_export('keras.layers.serialize') def serialize(layer): return generic_utils.serialize_keras_object(layer) -@keras_export('keras.layers.deserialize') def deserialize(config, custom_objects=None): """Instantiates a layer from a config dictionary. diff --git a/tensorflow/python/keras/legacy_tf_layers/base.py b/tensorflow/python/keras/legacy_tf_layers/base.py index eeb5626aaeb7e5..0225fbfa47a708 100644 --- a/tensorflow/python/keras/legacy_tf_layers/base.py +++ b/tensorflow/python/keras/legacy_tf_layers/base.py @@ -30,7 +30,6 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.trackable import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export # Avoid breaking users who directly import this symbol from this file. # TODO(fchollet): remove this. @@ -39,8 +38,6 @@ _KERAS_STYLE_SCOPE = False -@keras_export( - v1=['keras.__internal__.legacy.layers.experimental.keras_style_scope']) @tf_contextlib.contextmanager def keras_style_scope(): """Use Keras-style variable management. @@ -109,8 +106,6 @@ def call(self, input, state): _KERAS_STYLE_SCOPE = stack -@keras_export( - v1=['keras.__internal__.legacy.layers.experimental.set_keras_style']) def set_keras_style(): """Use Keras-style variable management. @@ -153,7 +148,6 @@ def _is_in_keras_style_scope(): return _KERAS_STYLE_SCOPE -@keras_export(v1=['keras.__internal__.legacy.layers.Layer']) class Layer(base_layer.Layer): """Base layer class. diff --git a/tensorflow/python/keras/legacy_tf_layers/convolutional.py b/tensorflow/python/keras/legacy_tf_layers/convolutional.py index f43f854334d52a..f8f91b12b8a098 100644 --- a/tensorflow/python/keras/legacy_tf_layers/convolutional.py +++ b/tensorflow/python/keras/legacy_tf_layers/convolutional.py @@ -19,10 +19,8 @@ from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras.legacy_tf_layers import base from tensorflow.python.ops import init_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export(v1=['keras.__internal__.legacy.layers.Conv1D']) class Conv1D(keras_layers.Conv1D, base.Layer): """1D convolution layer (e.g. temporal convolution). @@ -114,7 +112,6 @@ def __init__(self, filters, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.conv1d']) def conv1d(inputs, filters, kernel_size, @@ -220,7 +217,6 @@ def conv1d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.Conv2D']) class Conv2D(keras_layers.Conv2D, base.Layer): """2D convolution layer (e.g. spatial convolution over images). 
@@ -319,7 +315,6 @@ def __init__(self, filters, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.conv2d']) def conv2d(inputs, filters, kernel_size, @@ -432,7 +427,6 @@ def conv2d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.Conv3D']) class Conv3D(keras_layers.Conv3D, base.Layer): """3D convolution layer (e.g. spatial convolution over volumes). @@ -532,7 +526,6 @@ def __init__(self, filters, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.conv3d']) def conv3d(inputs, filters, kernel_size, @@ -646,7 +639,6 @@ def conv3d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.SeparableConv1D']) class SeparableConv1D(keras_layers.SeparableConv1D, base.Layer): """Depthwise separable 1D convolution. @@ -756,7 +748,6 @@ def __init__(self, filters, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.SeparableConv2D']) class SeparableConv2D(keras_layers.SeparableConv2D, base.Layer): """Depthwise separable 2D convolution. @@ -871,7 +862,6 @@ def __init__(self, filters, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.separable_conv1d']) def separable_conv1d(inputs, filters, kernel_size, @@ -994,7 +984,6 @@ def separable_conv1d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.separable_conv2d']) def separable_conv2d(inputs, filters, kernel_size, @@ -1122,7 +1111,6 @@ def separable_conv2d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.Conv2DTranspose']) class Conv2DTranspose(keras_layers.Conv2DTranspose, base.Layer): """Transposed 2D convolution layer (sometimes called 2D Deconvolution). @@ -1210,7 +1198,6 @@ def __init__(self, filters, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.conv2d_transpose']) def conv2d_transpose(inputs, filters, kernel_size, @@ -1311,7 +1298,6 @@ def conv2d_transpose(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.Conv3DTranspose']) class Conv3DTranspose(keras_layers.Conv3DTranspose, base.Layer): """Transposed 3D convolution layer (sometimes called 3D Deconvolution). @@ -1396,7 +1382,6 @@ def __init__(self, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.conv3d_transpose']) def conv3d_transpose(inputs, filters, kernel_size, diff --git a/tensorflow/python/keras/legacy_tf_layers/core.py b/tensorflow/python/keras/legacy_tf_layers/core.py index 2564a2da8ce65f..99d2c59ccadad8 100644 --- a/tensorflow/python/keras/legacy_tf_layers/core.py +++ b/tensorflow/python/keras/legacy_tf_layers/core.py @@ -22,10 +22,8 @@ from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras.legacy_tf_layers import base from tensorflow.python.ops import init_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export(v1=['keras.__internal__.legacy.layers.Dense']) class Dense(keras_layers.Dense, base.Layer): """Densely-connected layer class. @@ -106,7 +104,6 @@ def __init__(self, units, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.dense']) def dense( inputs, units, activation=None, @@ -184,7 +181,6 @@ def dense( return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.Dropout']) class Dropout(keras_layers.Dropout, base.Layer): """Applies Dropout to the input. 
@@ -223,7 +219,6 @@ def call(self, inputs, training=False): return super(Dropout, self).call(inputs, training=training) -@keras_export(v1=['keras.__internal__.legacy.layers.dropout']) def dropout(inputs, rate=0.5, noise_shape=None, @@ -268,7 +263,6 @@ def dropout(inputs, return layer.apply(inputs, training=training) -@keras_export(v1=['keras.__internal__.legacy.layers.Flatten']) class Flatten(keras_layers.Flatten, base.Layer): """Flattens an input tensor while preserving the batch axis (axis 0). @@ -294,7 +288,6 @@ class Flatten(keras_layers.Flatten, base.Layer): pass -@keras_export(v1=['keras.__internal__.legacy.layers.flatten']) def flatten(inputs, name=None, data_format='channels_last'): """Flattens an input tensor while preserving the batch axis (axis 0). diff --git a/tensorflow/python/keras/legacy_tf_layers/pooling.py b/tensorflow/python/keras/legacy_tf_layers/pooling.py index b7134abe8f0de6..68849eb39a20bc 100644 --- a/tensorflow/python/keras/legacy_tf_layers/pooling.py +++ b/tensorflow/python/keras/legacy_tf_layers/pooling.py @@ -18,10 +18,8 @@ from tensorflow.python.keras import layers as keras_layers from tensorflow.python.keras.legacy_tf_layers import base -from tensorflow.python.util.tf_export import keras_export -@keras_export(v1=['keras.__internal__.legacy.layers.AveragePooling1D']) class AveragePooling1D(keras_layers.AveragePooling1D, base.Layer): """Average Pooling layer for 1D inputs. @@ -54,7 +52,6 @@ def __init__(self, pool_size, strides, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.average_pooling1d']) def average_pooling1d(inputs, pool_size, strides, padding='valid', data_format='channels_last', name=None): @@ -92,7 +89,6 @@ def average_pooling1d(inputs, pool_size, strides, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.MaxPooling1D']) class MaxPooling1D(keras_layers.MaxPooling1D, base.Layer): """Max Pooling layer for 1D inputs. @@ -125,7 +121,6 @@ def __init__(self, pool_size, strides, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.max_pooling1d']) def max_pooling1d(inputs, pool_size, strides, padding='valid', data_format='channels_last', name=None): @@ -163,7 +158,6 @@ def max_pooling1d(inputs, pool_size, strides, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.AveragePooling2D']) class AveragePooling2D(keras_layers.AveragePooling2D, base.Layer): """Average pooling layer for 2D inputs (e.g. images). @@ -196,7 +190,6 @@ def __init__(self, pool_size, strides, padding=padding, data_format=data_format, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.average_pooling2d']) def average_pooling2d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -237,7 +230,6 @@ def average_pooling2d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.MaxPooling2D']) class MaxPooling2D(keras_layers.MaxPooling2D, base.Layer): """Max pooling layer for 2D inputs (e.g. images). @@ -270,7 +262,6 @@ def __init__(self, pool_size, strides, padding=padding, data_format=data_format, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.max_pooling2d']) def max_pooling2d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -311,7 +302,6 @@ def max_pooling2d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.AveragePooling3D']) class AveragePooling3D(keras_layers.AveragePooling3D, base.Layer): """Average pooling layer for 3D inputs (e.g. 
volumes). @@ -346,7 +336,6 @@ def __init__(self, pool_size, strides, padding=padding, data_format=data_format, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.average_pooling3d']) def average_pooling3d(inputs, pool_size, strides, padding='valid', data_format='channels_last', @@ -389,7 +378,6 @@ def average_pooling3d(inputs, return layer.apply(inputs) -@keras_export(v1=['keras.__internal__.legacy.layers.MaxPooling3D']) class MaxPooling3D(keras_layers.MaxPooling3D, base.Layer): """Max pooling layer for 3D inputs (e.g. volumes). @@ -424,7 +412,6 @@ def __init__(self, pool_size, strides, padding=padding, data_format=data_format, name=name, **kwargs) -@keras_export(v1=['keras.__internal__.legacy.layers.max_pooling3d']) def max_pooling3d(inputs, pool_size, strides, padding='valid', data_format='channels_last', diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index bd40507cda8c12..f03e7de0ed932f 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -42,11 +42,9 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.util import dispatch -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls -@keras_export('keras.losses.Loss') class Loss: """Loss base class. @@ -267,7 +265,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.losses.MeanSquaredError') class MeanSquaredError(LossFunctionWrapper): """Computes the mean of squares of errors between labels and predictions. @@ -325,7 +322,6 @@ def __init__(self, super().__init__(mean_squared_error, name=name, reduction=reduction) -@keras_export('keras.losses.MeanAbsoluteError') class MeanAbsoluteError(LossFunctionWrapper): """Computes the mean of absolute difference between labels and predictions. @@ -383,7 +379,6 @@ def __init__(self, super().__init__(mean_absolute_error, name=name, reduction=reduction) -@keras_export('keras.losses.MeanAbsolutePercentageError') class MeanAbsolutePercentageError(LossFunctionWrapper): """Computes the mean absolute percentage error between `y_true` and `y_pred`. @@ -444,7 +439,6 @@ def __init__(self, mean_absolute_percentage_error, name=name, reduction=reduction) -@keras_export('keras.losses.MeanSquaredLogarithmicError') class MeanSquaredLogarithmicError(LossFunctionWrapper): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. @@ -505,7 +499,6 @@ def __init__(self, mean_squared_logarithmic_error, name=name, reduction=reduction) -@keras_export('keras.losses.BinaryCrossentropy') class BinaryCrossentropy(LossFunctionWrapper): """Computes the cross-entropy loss between true labels and predicted labels. @@ -608,7 +601,6 @@ def __init__(self, self.from_logits = from_logits -@keras_export('keras.losses.CategoricalCrossentropy') class CategoricalCrossentropy(LossFunctionWrapper): """Computes the crossentropy loss between the labels and predictions. @@ -691,7 +683,6 @@ def __init__(self, axis=axis) -@keras_export('keras.losses.SparseCategoricalCrossentropy') class SparseCategoricalCrossentropy(LossFunctionWrapper): """Computes the crossentropy loss between the labels and predictions. @@ -767,7 +758,6 @@ def __init__(self, from_logits=from_logits) -@keras_export('keras.losses.Hinge') class Hinge(LossFunctionWrapper): """Computes the hinge loss between `y_true` and `y_pred`. 
@@ -826,7 +816,6 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'): super().__init__(hinge, name=name, reduction=reduction) -@keras_export('keras.losses.SquaredHinge') class SquaredHinge(LossFunctionWrapper): """Computes the squared hinge loss between `y_true` and `y_pred`. @@ -887,7 +876,6 @@ def __init__(self, super().__init__(squared_hinge, name=name, reduction=reduction) -@keras_export('keras.losses.CategoricalHinge') class CategoricalHinge(LossFunctionWrapper): """Computes the categorical hinge loss between `y_true` and `y_pred`. @@ -946,7 +934,6 @@ def __init__(self, super().__init__(categorical_hinge, name=name, reduction=reduction) -@keras_export('keras.losses.Poisson') class Poisson(LossFunctionWrapper): """Computes the Poisson loss between `y_true` and `y_pred`. @@ -1002,7 +989,6 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'): super().__init__(poisson, name=name, reduction=reduction) -@keras_export('keras.losses.LogCosh') class LogCosh(LossFunctionWrapper): """Computes the logarithm of the hyperbolic cosine of the prediction error. @@ -1059,7 +1045,6 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'): super().__init__(log_cosh, name=name, reduction=reduction) -@keras_export('keras.losses.KLDivergence') class KLDivergence(LossFunctionWrapper): """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. @@ -1119,7 +1104,6 @@ def __init__(self, super().__init__(kl_divergence, name=name, reduction=reduction) -@keras_export('keras.losses.Huber') class Huber(LossFunctionWrapper): """Computes the Huber loss between `y_true` and `y_pred`. @@ -1186,9 +1170,6 @@ def __init__(self, super().__init__(huber, name=name, reduction=reduction, delta=delta) -@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', - 'keras.metrics.MSE', 'keras.losses.mean_squared_error', - 'keras.losses.mse', 'keras.losses.MSE') @dispatch.add_dispatch_support def mean_squared_error(y_true, y_pred): """Computes the mean squared error between labels and predictions. @@ -1317,9 +1298,6 @@ def _ragged_tensor_mse(y_true, y_pred): return _ragged_tensor_apply_loss(mean_squared_error, y_true, y_pred) -@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae', - 'keras.metrics.MAE', 'keras.losses.mean_absolute_error', - 'keras.losses.mae', 'keras.losses.MAE') @dispatch.add_dispatch_support def mean_absolute_error(y_true, y_pred): """Computes the mean absolute error between labels and predictions. @@ -1353,10 +1331,6 @@ def _ragged_tensor_mae(y_true, y_pred): return _ragged_tensor_apply_loss(mean_absolute_error, y_true, y_pred) -@keras_export('keras.metrics.mean_absolute_percentage_error', - 'keras.metrics.mape', 'keras.metrics.MAPE', - 'keras.losses.mean_absolute_percentage_error', - 'keras.losses.mape', 'keras.losses.MAPE') @dispatch.add_dispatch_support def mean_absolute_percentage_error(y_true, y_pred): """Computes the mean absolute percentage error between `y_true` and `y_pred`. @@ -1397,10 +1371,6 @@ def _ragged_tensor_mape(y_true, y_pred): y_pred) -@keras_export('keras.metrics.mean_squared_logarithmic_error', - 'keras.metrics.msle', 'keras.metrics.MSLE', - 'keras.losses.mean_squared_logarithmic_error', - 'keras.losses.msle', 'keras.losses.MSLE') @dispatch.add_dispatch_support def mean_squared_logarithmic_error(y_true, y_pred): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 
@@ -1458,7 +1428,6 @@ def _convert_binary_labels(): return updated_y_true -@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') @dispatch.add_dispatch_support def squared_hinge(y_true, y_pred): """Computes the squared hinge loss between `y_true` and `y_pred`. @@ -1491,7 +1460,6 @@ def squared_hinge(y_true, y_pred): math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1) -@keras_export('keras.metrics.hinge', 'keras.losses.hinge') @dispatch.add_dispatch_support def hinge(y_true, y_pred): """Computes the hinge loss between `y_true` and `y_pred`. @@ -1523,7 +1491,6 @@ def hinge(y_true, y_pred): return backend.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1) -@keras_export('keras.losses.categorical_hinge') @dispatch.add_dispatch_support def categorical_hinge(y_true, y_pred): """Computes the categorical hinge loss between `y_true` and `y_pred`. @@ -1558,7 +1525,6 @@ def categorical_hinge(y_true, y_pred): return math_ops.maximum(neg - pos + 1., zero) -@keras_export('keras.losses.huber', v1=[]) @dispatch.add_dispatch_support def huber(y_true, y_pred, delta=1.0): """Computes Huber loss value. @@ -1594,8 +1560,6 @@ def huber(y_true, y_pred, delta=1.0): axis=-1) -@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh', - 'keras.metrics.log_cosh', 'keras.metrics.logcosh') @dispatch.add_dispatch_support def log_cosh(y_true, y_pred): """Logarithm of the hyperbolic cosine of the prediction error. @@ -1634,8 +1598,6 @@ def _logcosh(x): return backend.mean(_logcosh(y_pred - y_true), axis=-1) -@keras_export('keras.metrics.categorical_crossentropy', - 'keras.losses.categorical_crossentropy') @dispatch.add_dispatch_support def categorical_crossentropy(y_true, y_pred, @@ -1725,8 +1687,6 @@ def _ragged_tensor_categorical_crossentropy(y_true, return _ragged_tensor_apply_loss(fn, y_true, y_pred) -@keras_export('keras.metrics.sparse_categorical_crossentropy', - 'keras.losses.sparse_categorical_crossentropy') @dispatch.add_dispatch_support def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): """Computes the sparse categorical crossentropy loss. @@ -1780,8 +1740,6 @@ def _ragged_tensor_sparse_categorical_crossentropy(y_true, return _ragged_tensor_apply_loss(fn, y_true, y_pred, y_pred_extra_dim=True) -@keras_export('keras.metrics.binary_crossentropy', - 'keras.losses.binary_crossentropy') @dispatch.add_dispatch_support def binary_crossentropy(y_true, y_pred, @@ -1866,11 +1824,6 @@ def _ragged_tensor_binary_crossentropy(y_true, return _ragged_tensor_apply_loss(fn, y_true, y_pred) -@keras_export('keras.metrics.kl_divergence', - 'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld', - 'keras.metrics.KLD', 'keras.losses.kl_divergence', - 'keras.losses.kullback_leibler_divergence', 'keras.losses.kld', - 'keras.losses.KLD') @dispatch.add_dispatch_support def kl_divergence(y_true, y_pred): """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. @@ -1907,7 +1860,6 @@ def kl_divergence(y_true, y_pred): return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1) -@keras_export('keras.metrics.poisson', 'keras.losses.poisson') @dispatch.add_dispatch_support def poisson(y_true, y_pred): """Computes the Poisson loss between y_true and y_pred. 
@@ -1942,15 +1894,6 @@ def poisson(y_true, y_pred): y_pred - y_true * math_ops.log(y_pred + backend.epsilon()), axis=-1) -@keras_export( - 'keras.losses.cosine_similarity', - v1=[ - 'keras.metrics.cosine_proximity', - 'keras.metrics.cosine', - 'keras.losses.cosine_proximity', - 'keras.losses.cosine', - 'keras.losses.cosine_similarity', - ]) @dispatch.add_dispatch_support def cosine_similarity(y_true, y_pred, axis=-1): """Computes the cosine similarity between labels and predictions. @@ -1987,7 +1930,6 @@ def cosine_similarity(y_true, y_pred, axis=-1): return -math_ops.reduce_sum(y_true * y_pred, axis=axis) -@keras_export('keras.losses.CosineSimilarity') class CosineSimilarity(LossFunctionWrapper): """Computes the cosine similarity between labels and predictions. @@ -2082,7 +2024,6 @@ def is_categorical_crossentropy(loss): return result -@keras_export('keras.losses.serialize') def serialize(loss): """Serializes loss function or `Loss` instance. @@ -2095,7 +2036,6 @@ def serialize(loss): return serialize_keras_object(loss) -@keras_export('keras.losses.deserialize') def deserialize(name, custom_objects=None): """Deserializes a serialized loss class/function instance. @@ -2114,7 +2054,6 @@ def deserialize(name, custom_objects=None): printable_module_name='loss function') -@keras_export('keras.losses.get') def get(identifier): """Retrieves a Keras loss as a `function`/`Loss` class instance. diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 2a3fc0ce872397..96e63a57d4d206 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -68,11 +68,9 @@ from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.util import dispatch from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls -@keras_export('keras.metrics.Metric') class Metric(base_layer.Layer, metaclass=abc.ABCMeta): """Encapsulates metric logic and state. @@ -451,7 +449,6 @@ def result(self): 'reduction [%s] not implemented' % self.reduction) -@keras_export('keras.metrics.Sum') class Sum(Reduce): """Computes the (weighted) sum of the given values. @@ -488,7 +485,6 @@ def __init__(self, name='sum', dtype=None): name=name, dtype=dtype) -@keras_export('keras.metrics.Mean') class Mean(Reduce): """Computes the (weighted) mean of the given values. @@ -530,7 +526,6 @@ def __init__(self, name='mean', dtype=None): reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype) -@keras_export('keras.metrics.MeanRelativeError') class MeanRelativeError(Mean): """Computes the mean relative error by normalizing with the given values. @@ -610,7 +605,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.MeanMetricWrapper') class MeanMetricWrapper(Mean): """Wraps a stateless metric function with the Mean metric. @@ -700,7 +694,6 @@ def from_config(cls, config): return super(MeanMetricWrapper, cls).from_config(config) -@keras_export('keras.metrics.Accuracy') class Accuracy(MeanMetricWrapper): """Calculates how often predictions equal labels. @@ -742,7 +735,6 @@ def __init__(self, name='accuracy', dtype=None): super(Accuracy, self).__init__(accuracy, name, dtype=dtype) -@keras_export('keras.metrics.BinaryAccuracy') class BinaryAccuracy(MeanMetricWrapper): """Calculates how often predictions match binary labels. 
@@ -787,7 +779,6 @@ def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): binary_accuracy, name, dtype=dtype, threshold=threshold) -@keras_export('keras.metrics.CategoricalAccuracy') class CategoricalAccuracy(MeanMetricWrapper): """Calculates how often predictions match one-hot labels. @@ -839,7 +830,6 @@ def __init__(self, name='categorical_accuracy', dtype=None): categorical_accuracy, name, dtype=dtype) -@keras_export('keras.metrics.SparseCategoricalAccuracy') class SparseCategoricalAccuracy(MeanMetricWrapper): """Calculates how often predictions match integer labels. @@ -890,7 +880,6 @@ def __init__(self, name='sparse_categorical_accuracy', dtype=None): sparse_categorical_accuracy, name, dtype=dtype) -@keras_export('keras.metrics.TopKCategoricalAccuracy') class TopKCategoricalAccuracy(MeanMetricWrapper): """Computes how often targets are in the top `K` predictions. @@ -929,7 +918,6 @@ def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None): top_k_categorical_accuracy, name, dtype=dtype, k=k) -@keras_export('keras.metrics.SparseTopKCategoricalAccuracy') class SparseTopKCategoricalAccuracy(MeanMetricWrapper): """Computes how often integer targets are in the top `K` predictions. @@ -1037,7 +1025,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.FalsePositives') class FalsePositives(_ConfusionMatrixConditionCount): """Calculates the number of false positives. @@ -1086,7 +1073,6 @@ def __init__(self, thresholds=None, name=None, dtype=None): dtype=dtype) -@keras_export('keras.metrics.FalseNegatives') class FalseNegatives(_ConfusionMatrixConditionCount): """Calculates the number of false negatives. @@ -1135,7 +1121,6 @@ def __init__(self, thresholds=None, name=None, dtype=None): dtype=dtype) -@keras_export('keras.metrics.TrueNegatives') class TrueNegatives(_ConfusionMatrixConditionCount): """Calculates the number of true negatives. @@ -1184,7 +1169,6 @@ def __init__(self, thresholds=None, name=None, dtype=None): dtype=dtype) -@keras_export('keras.metrics.TruePositives') class TruePositives(_ConfusionMatrixConditionCount): """Calculates the number of true positives. @@ -1233,7 +1217,6 @@ def __init__(self, thresholds=None, name=None, dtype=None): dtype=dtype) -@keras_export('keras.metrics.Precision') class Precision(Metric): """Computes the precision of the predictions with respect to the labels. @@ -1375,7 +1358,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.Recall') class Recall(Metric): """Computes the recall of the predictions with respect to the labels. @@ -1611,7 +1593,6 @@ def _find_max_under_constraint(self, constrained, dependent, predicate): return array_ops.where_v2(feasible_exists, max_dependent, 0.0) -@keras_export('keras.metrics.SensitivityAtSpecificity') class SensitivityAtSpecificity(SensitivitySpecificityBase): """Computes best sensitivity where specificity is >= specified value. @@ -1705,7 +1686,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.SpecificityAtSensitivity') class SpecificityAtSensitivity(SensitivitySpecificityBase): """Computes best specificity where sensitivity is >= specified value. 
@@ -1797,7 +1777,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.PrecisionAtRecall') class PrecisionAtRecall(SensitivitySpecificityBase): """Computes best precision where recall is >= specified value. @@ -1878,7 +1857,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.RecallAtPrecision') class RecallAtPrecision(SensitivitySpecificityBase): """Computes best recall where precision is >= specified value. @@ -1963,7 +1941,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.AUC') class AUC(Metric): """Approximates the AUC (Area under the curve) of the ROC or PR curves. @@ -2444,7 +2421,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.CosineSimilarity') class CosineSimilarity(MeanMetricWrapper): """Computes the cosine similarity between the labels and predictions. @@ -2494,7 +2470,6 @@ def __init__(self, name='cosine_similarity', dtype=None, axis=-1): cosine_similarity, name, dtype=dtype, axis=axis) -@keras_export('keras.metrics.MeanAbsoluteError') class MeanAbsoluteError(MeanMetricWrapper): """Computes the mean absolute error between the labels and predictions. @@ -2530,7 +2505,6 @@ def __init__(self, name='mean_absolute_error', dtype=None): mean_absolute_error, name, dtype=dtype) -@keras_export('keras.metrics.MeanAbsolutePercentageError') class MeanAbsolutePercentageError(MeanMetricWrapper): """Computes the mean absolute percentage error between `y_true` and `y_pred`. @@ -2566,7 +2540,6 @@ def __init__(self, name='mean_absolute_percentage_error', dtype=None): mean_absolute_percentage_error, name, dtype=dtype) -@keras_export('keras.metrics.MeanSquaredError') class MeanSquaredError(MeanMetricWrapper): """Computes the mean squared error between `y_true` and `y_pred`. @@ -2602,7 +2575,6 @@ def __init__(self, name='mean_squared_error', dtype=None): mean_squared_error, name, dtype=dtype) -@keras_export('keras.metrics.MeanSquaredLogarithmicError') class MeanSquaredLogarithmicError(MeanMetricWrapper): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. @@ -2638,7 +2610,6 @@ def __init__(self, name='mean_squared_logarithmic_error', dtype=None): mean_squared_logarithmic_error, name, dtype=dtype) -@keras_export('keras.metrics.Hinge') class Hinge(MeanMetricWrapper): """Computes the hinge metric between `y_true` and `y_pred`. @@ -2673,7 +2644,6 @@ def __init__(self, name='hinge', dtype=None): super(Hinge, self).__init__(hinge, name, dtype=dtype) -@keras_export('keras.metrics.SquaredHinge') class SquaredHinge(MeanMetricWrapper): """Computes the squared hinge metric between `y_true` and `y_pred`. @@ -2711,7 +2681,6 @@ def __init__(self, name='squared_hinge', dtype=None): super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype) -@keras_export('keras.metrics.CategoricalHinge') class CategoricalHinge(MeanMetricWrapper): """Computes the categorical hinge metric between `y_true` and `y_pred`. @@ -2746,7 +2715,6 @@ def __init__(self, name='categorical_hinge', dtype=None): super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) -@keras_export('keras.metrics.RootMeanSquaredError') class RootMeanSquaredError(Mean): """Computes root mean squared error metric between `y_true` and `y_pred`. 
@@ -2801,7 +2769,6 @@ def result(self): return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count)) -@keras_export('keras.metrics.LogCoshError') class LogCoshError(MeanMetricWrapper): """Computes the logarithm of the hyperbolic cosine of the prediction error. @@ -2837,7 +2804,6 @@ def __init__(self, name='logcosh', dtype=None): super(LogCoshError, self).__init__(logcosh, name, dtype=dtype) -@keras_export('keras.metrics.Poisson') class Poisson(MeanMetricWrapper): """Computes the Poisson metric between `y_true` and `y_pred`. @@ -2873,7 +2839,6 @@ def __init__(self, name='poisson', dtype=None): super(Poisson, self).__init__(poisson, name, dtype=dtype) -@keras_export('keras.metrics.KLDivergence') class KLDivergence(MeanMetricWrapper): """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`. @@ -2910,7 +2875,6 @@ def __init__(self, name='kullback_leibler_divergence', dtype=None): kullback_leibler_divergence, name, dtype=dtype) -@keras_export('keras.metrics.MeanIoU') class MeanIoU(Metric): """Computes the mean Intersection-Over-Union metric. @@ -3041,7 +3005,6 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -@keras_export('keras.metrics.MeanTensor') class MeanTensor(Metric): """Computes the element-wise (weighted) mean of the given tensors. @@ -3165,7 +3128,6 @@ def reset_state(self): [(v, np.zeros(self._shape.as_list())) for v in self.variables]) -@keras_export('keras.metrics.BinaryCrossentropy') class BinaryCrossentropy(MeanMetricWrapper): """Computes the crossentropy metric between the labels and predictions. @@ -3218,7 +3180,6 @@ def __init__(self, label_smoothing=label_smoothing) -@keras_export('keras.metrics.CategoricalCrossentropy') class CategoricalCrossentropy(MeanMetricWrapper): """Computes the crossentropy metric between the labels and predictions. @@ -3282,7 +3243,6 @@ def __init__(self, label_smoothing=label_smoothing) -@keras_export('keras.metrics.SparseCategoricalCrossentropy') class SparseCategoricalCrossentropy(MeanMetricWrapper): """Computes the crossentropy metric between the labels and predictions. @@ -3421,7 +3381,6 @@ def accuracy(y_true, y_pred): return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx()) -@keras_export('keras.metrics.binary_accuracy') @dispatch.add_dispatch_support def binary_accuracy(y_true, y_pred, threshold=0.5): """Calculates how often predictions match binary labels. @@ -3449,7 +3408,6 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): return backend.mean(math_ops.equal(y_true, y_pred), axis=-1) -@keras_export('keras.metrics.categorical_accuracy') @dispatch.add_dispatch_support def categorical_accuracy(y_true, y_pred): """Calculates how often predictions match one-hot labels. @@ -3478,7 +3436,6 @@ def categorical_accuracy(y_true, y_pred): backend.floatx()) -@keras_export('keras.metrics.sparse_categorical_accuracy') @dispatch.add_dispatch_support def sparse_categorical_accuracy(y_true, y_pred): """Calculates how often predictions match integer labels. @@ -3519,7 +3476,6 @@ def sparse_categorical_accuracy(y_true, y_pred): return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx()) -@keras_export('keras.metrics.top_k_categorical_accuracy') @dispatch.add_dispatch_support def top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often targets are in the top `K` predictions. 
@@ -3546,7 +3502,6 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5): y_pred, math_ops.argmax(y_true, axis=-1), k), backend.floatx()) -@keras_export('keras.metrics.sparse_top_k_categorical_accuracy') @dispatch.add_dispatch_support def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often integer targets are in the top `K` predictions. @@ -3627,7 +3582,6 @@ def clone_metrics(metrics): return nest.map_structure(clone_metric, metrics) -@keras_export('keras.metrics.serialize') def serialize(metric): """Serializes metric function or `Metric` instance. @@ -3640,7 +3594,6 @@ def serialize(metric): return serialize_keras_object(metric) -@keras_export('keras.metrics.deserialize') def deserialize(config, custom_objects=None): """Deserializes a serialized metric class/function instance. @@ -3659,7 +3612,6 @@ def deserialize(config, custom_objects=None): printable_module_name='metric function') -@keras_export('keras.metrics.get') def get(identifier): """Retrieves a Keras metric as a `function`/`Metric` class instance. diff --git a/tensorflow/python/keras/mixed_precision/get_layer_policy.py b/tensorflow/python/keras/mixed_precision/get_layer_policy.py index 3a17f25565cdbf..1bcee88d6114e4 100644 --- a/tensorflow/python/keras/mixed_precision/get_layer_policy.py +++ b/tensorflow/python/keras/mixed_precision/get_layer_policy.py @@ -19,10 +19,8 @@ """ from tensorflow.python.keras.engine import base_layer -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.mixed_precision.experimental.get_layer_policy', v1=[]) def get_layer_policy(layer): """Returns the dtype policy of a layer. diff --git a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py index 9799c4117962e7..b8895015266491 100644 --- a/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py @@ -42,7 +42,6 @@ from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.experimental import mixed_precision from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export class _UnwrapPreventer(object): @@ -281,7 +280,6 @@ def update_if_not_finite_grads(): # pylint: disable=g-classes-have-attributes -@keras_export('keras.mixed_precision.LossScaleOptimizer') class LossScaleOptimizer(base_delegate.DelegatingTrackableMixin, optimizer_v2.OptimizerV2): """An optimizer that applies loss scaling to prevent numeric underflow. @@ -887,7 +885,6 @@ def lr(self, value): # optimizer being used. -@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') class LossScaleOptimizerV1(LossScaleOptimizer): """An deprecated optimizer that applies loss scaling. diff --git a/tensorflow/python/keras/mixed_precision/policy.py b/tensorflow/python/keras/mixed_precision/policy.py index 6966968c19b551..36e558bba666b8 100644 --- a/tensorflow/python/keras/mixed_precision/policy.py +++ b/tensorflow/python/keras/mixed_precision/policy.py @@ -24,11 +24,9 @@ from tensorflow.python.keras.utils import generic_utils from tensorflow.python.platform import tf_logging from tensorflow.python.training.experimental import mixed_precision_global_state -from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-classes-have-attributes -@keras_export('keras.mixed_precision.Policy', v1=[]) class Policy(object): """A dtype policy for a Keras layer. 
@@ -309,7 +307,6 @@ def from_config(cls, config, custom_objects=None): return cls(**config) -@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) class PolicyV1(Policy): """A deprecated dtype policy for a Keras layer. @@ -409,8 +406,6 @@ def from_config(cls, config, custom_objects=None): _global_policy = None -@keras_export('keras.mixed_precision.global_policy', - 'keras.mixed_precision.experimental.global_policy', v1=[]) def global_policy(): """Returns the global dtype policy. @@ -459,8 +454,6 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy): 'customizable.'.format(policy=policy)) -@keras_export('keras.mixed_precision.set_global_policy', - 'keras.mixed_precision.experimental.set_global_policy', v1=[]) def set_global_policy(policy): """Sets the global dtype policy. diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 31ba64daed85b0..b5eaccc3c579bd 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -34,7 +34,6 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export # API entries importable from `keras.models`: @@ -381,7 +380,6 @@ def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer): return cloned_model -@keras_export('keras.models.clone_model') def clone_model(model, input_tensors=None, clone_function=None): """Clone a Functional or Sequential `Model` instance. @@ -578,9 +576,6 @@ def _reset_build_compile_trackers(model): model.optimizer = None -@keras_export( - 'keras.__internal__.models.in_place_subclassed_model_state_restoration', - v1=[]) def in_place_subclassed_model_state_restoration(model): """Restores the original state of a model after it was "reset". @@ -613,7 +608,6 @@ def in_place_subclassed_model_state_restoration(model): _reset_build_compile_trackers(model) -@keras_export('keras.__internal__.models.clone_and_build_model', v1=[]) def clone_and_build_model( model, input_tensors=None, target_tensors=None, custom_objects=None, compile_clone=True, in_place_reset=False, optimizer_iterations=None, diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py index 611537b2b8f8a5..f2264bd123543d 100644 --- a/tensorflow/python/keras/optimizer_v2/adadelta.py +++ b/tensorflow/python/keras/optimizer_v2/adadelta.py @@ -21,10 +21,8 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Adadelta') class Adadelta(optimizer_v2.OptimizerV2): r"""Optimizer that implements the Adadelta algorithm. diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py index 8bf410ef0f1a1f..c59e165db02409 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad.py @@ -24,10 +24,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Adagrad') class Adagrad(optimizer_v2.OptimizerV2): r"""Optimizer that implements the Adagrad algorithm. 
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index ffd2801117c69f..7b1e90ad9088fe 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -26,10 +26,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Adam') class Adam(optimizer_v2.OptimizerV2): r"""Optimizer that implements the Adam algorithm. diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py index 83622a505184a2..f5e009393a9061 100644 --- a/tensorflow/python/keras/optimizer_v2/adamax.py +++ b/tensorflow/python/keras/optimizer_v2/adamax.py @@ -24,10 +24,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Adamax') class Adamax(optimizer_v2.OptimizerV2): """Optimizer that implements the Adamax algorithm. diff --git a/tensorflow/python/keras/optimizer_v2/ftrl.py b/tensorflow/python/keras/optimizer_v2/ftrl.py index 27c5989dbd217f..6e9ba7206ff7db 100644 --- a/tensorflow/python/keras/optimizer_v2/ftrl.py +++ b/tensorflow/python/keras/optimizer_v2/ftrl.py @@ -20,10 +20,8 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Ftrl') class Ftrl(optimizer_v2.OptimizerV2): r"""Optimizer that implements the FTRL algorithm. diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py index 87c3b543578973..fe4e01db426f8e 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py @@ -20,10 +20,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export("keras.optimizers.SGD") class SGD(optimizer_v2.OptimizerV2): r"""Gradient descent (with momentum) optimizer. diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py index 08438301d17e83..7c701e321c8df6 100644 --- a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py +++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py @@ -27,10 +27,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export -@keras_export("keras.optimizers.schedules.LearningRateSchedule") class LearningRateSchedule(object): """The learning rate schedule base class. @@ -95,7 +93,6 @@ def from_config(cls, config): return cls(**config) -@keras_export("keras.optimizers.schedules.ExponentialDecay") class ExponentialDecay(LearningRateSchedule): """A LearningRateSchedule that uses an exponential decay schedule. 
@@ -203,7 +200,6 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay") class PiecewiseConstantDecay(LearningRateSchedule): """A LearningRateSchedule that uses a piecewise constant decay schedule. @@ -309,7 +305,6 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.PolynomialDecay") class PolynomialDecay(LearningRateSchedule): """A LearningRateSchedule that uses a polynomial decay schedule. @@ -456,7 +451,6 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.InverseTimeDecay") class InverseTimeDecay(LearningRateSchedule): """A LearningRateSchedule that uses an inverse time decay schedule. @@ -565,8 +559,6 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.CosineDecay", - "keras.experimental.CosineDecay") class CosineDecay(LearningRateSchedule): """A LearningRateSchedule that uses a cosine decay schedule. @@ -662,8 +654,6 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.CosineDecayRestarts", - "keras.experimental.CosineDecayRestarts") class CosineDecayRestarts(LearningRateSchedule): """A LearningRateSchedule that uses a cosine decay schedule with restarts. @@ -1054,12 +1044,10 @@ def get_config(self): } -@keras_export("keras.optimizers.schedules.serialize") def serialize(learning_rate_schedule): return generic_utils.serialize_keras_object(learning_rate_schedule) -@keras_export("keras.optimizers.schedules.deserialize") def deserialize(config, custom_objects=None): return generic_utils.deserialize_keras_object( config, diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py index 5c9cbb92e413af..1440c6aaaabf43 100644 --- a/tensorflow/python/keras/optimizer_v2/nadam.py +++ b/tensorflow/python/keras/optimizer_v2/nadam.py @@ -25,10 +25,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.Nadam') class Nadam(optimizer_v2.OptimizerV2): r"""Optimizer that implements the NAdam algorithm. Much like Adam is essentially RMSprop with momentum, Nadam is Adam with diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 755363545e21ae..d239d49951346a 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -50,7 +50,6 @@ from tensorflow.python.saved_model import revived_types from tensorflow.python.trackable import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export _DEFAULT_VALID_DTYPES = frozenset([ @@ -109,7 +108,6 @@ def name_scope_only_in_function_or_graph(name): return NullContextmanager() -@keras_export("keras.optimizers.Optimizer", metaclass=abc.ABCMeta) class OptimizerV2(trackable.Trackable): """Base class for Keras optimizers. 
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py index f752c41eeaf903..39fa85af7690b3 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py @@ -27,10 +27,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import gen_training_ops -from tensorflow.python.util.tf_export import keras_export -@keras_export("keras.optimizers.RMSprop") class RMSprop(optimizer_v2.OptimizerV2): r"""Optimizer that implements the RMSprop algorithm. diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index 9a18128b780a5b..51fad0fc94790d 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -33,10 +33,8 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.training import optimizer as tf_optimizer_module -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.optimizers.serialize') def serialize(optimizer): """Serialize the optimizer configuration to JSON compatible python dict. @@ -57,7 +55,6 @@ def serialize(optimizer): return serialize_keras_object(optimizer) -@keras_export('keras.optimizers.deserialize') def deserialize(config, custom_objects=None): """Inverse of the `serialize` function. @@ -98,7 +95,6 @@ def deserialize(config, custom_objects=None): printable_module_name='optimizer') -@keras_export('keras.optimizers.get') def get(identifier): """Retrieves a Keras Optimizer instance. diff --git a/tensorflow/python/keras/regularizers.py b/tensorflow/python/keras/regularizers.py index ac00bd4589a329..634ea39f0e4bf6 100644 --- a/tensorflow/python/keras/regularizers.py +++ b/tensorflow/python/keras/regularizers.py @@ -21,7 +21,6 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops -from tensorflow.python.util.tf_export import keras_export def _check_penalty_number(x): @@ -41,7 +40,6 @@ def _none_to_default(inputs, default): return default if inputs is None else default -@keras_export('keras.regularizers.Regularizer') class Regularizer(object): """Regularizer base class. @@ -206,7 +204,6 @@ def get_config(self): raise NotImplementedError(str(self) + ' does not implement get_config()') -@keras_export('keras.regularizers.L1L2') class L1L2(Regularizer): """A regularizer that applies both L1 and L2 regularization penalties. @@ -251,7 +248,6 @@ def get_config(self): return {'l1': float(self.l1), 'l2': float(self.l2)} -@keras_export('keras.regularizers.L1', 'keras.regularizers.l1') class L1(Regularizer): """A regularizer that applies a L1 regularization penalty. @@ -285,7 +281,6 @@ def get_config(self): return {'l1': float(self.l1)} -@keras_export('keras.regularizers.L2', 'keras.regularizers.l2') class L2(Regularizer): """A regularizer that applies a L2 regularization penalty. @@ -319,7 +314,6 @@ def get_config(self): return {'l2': float(self.l2)} -@keras_export('keras.regularizers.l1_l2') def l1_l2(l1=0.01, l2=0.01): # pylint: disable=redefined-outer-name r"""Create a regularizer that applies both L1 and L2 penalties. 
@@ -344,12 +338,10 @@ def l1_l2(l1=0.01, l2=0.01): # pylint: disable=redefined-outer-name l2 = L2 -@keras_export('keras.regularizers.serialize') def serialize(regularizer): return serialize_keras_object(regularizer) -@keras_export('keras.regularizers.deserialize') def deserialize(config, custom_objects=None): if config == 'l1_l2': # Special case necessary since the defaults used for "l1_l2" (string) @@ -362,7 +354,6 @@ def deserialize(config, custom_objects=None): printable_module_name='regularizer') -@keras_export('keras.regularizers.get') def get(identifier): """Retrieve a regularizer instance from a config or identifier.""" if identifier is None: diff --git a/tensorflow/python/keras/saving/model_config.py b/tensorflow/python/keras/saving/model_config.py index 1f4309e0c461e8..fc151339edbeba 100644 --- a/tensorflow/python/keras/saving/model_config.py +++ b/tensorflow/python/keras/saving/model_config.py @@ -16,10 +16,8 @@ """Functions that save the model's config into different formats.""" from tensorflow.python.keras.saving.saved_model import json_utils -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.models.model_from_config') def model_from_config(config, custom_objects=None): """Instantiates a Keras model from its config. @@ -52,7 +50,6 @@ def model_from_config(config, custom_objects=None): return deserialize(config, custom_objects=custom_objects) -@keras_export('keras.models.model_from_yaml') def model_from_yaml(yaml_string, custom_objects=None): """Parses a yaml model configuration file and returns a model instance. @@ -78,7 +75,6 @@ def model_from_yaml(yaml_string, custom_objects=None): ) -@keras_export('keras.models.model_from_json') def model_from_json(json_string, custom_objects=None): """Parses a JSON model configuration string and returns a model instance. diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index dc8c86cbbb3866..eee859233e5eba 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -22,7 +22,6 @@ from tensorflow.python.keras.saving.saved_model import save as saved_model_save from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils.io_utils import path_to_string -from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-import-not-at-top try: @@ -32,7 +31,6 @@ # pylint: enable=g-import-not-at-top -@keras_export('keras.models.save_model') def save_model(model, filepath, overwrite=True, @@ -150,7 +148,6 @@ def save_model(model, signatures, options, save_traces) -@keras_export('keras.models.load_model') def load_model(filepath, custom_objects=None, compile=True, options=None): # pylint: disable=redefined-builtin """Loads a model saved via `model.save()`. diff --git a/tensorflow/python/keras/saving/saved_model_experimental.py b/tensorflow/python/keras/saving/saved_model_experimental.py index 90493e22019f6d..2b0067274f8fcd 100644 --- a/tensorflow/python/keras/saving/saved_model_experimental.py +++ b/tensorflow/python/keras/saving/saved_model_experimental.py @@ -37,7 +37,6 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export # To avoid circular dependencies between keras/engine and keras/saving, # code in keras/saving must delay imports. 
@@ -59,7 +58,6 @@ SAVED_MODEL_FILENAME_JSON = 'saved_model.json' -@keras_export(v1=['keras.experimental.export_saved_model']) def export_saved_model(model, saved_model_path, custom_objects=None, @@ -370,7 +368,6 @@ def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): return True -@keras_export(v1=['keras.experimental.load_from_saved_model']) def load_from_saved_model(saved_model_path, custom_objects=None): """Loads a keras Model from a SavedModel created by `export_saved_model()`. diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index c9c4be339e1e9b..2fa63910a3f06b 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -41,7 +41,6 @@ from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.io_utils import path_to_string -from tensorflow.python.util.tf_export import keras_export # Required to support google internal urlretrieve if sys.version_info[0] == 2: @@ -145,7 +144,6 @@ def _extract_archive(file_path, path='.', archive_format='auto'): return False -@keras_export('keras.utils.get_file') def get_file(fname, origin, untar=False, @@ -389,7 +387,6 @@ def g(*a, **kw): return g -@keras_export('keras.utils.Sequence') class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. @@ -547,7 +544,6 @@ def get_index(uid, i): return _SHARED_SEQUENCES[uid][i] -@keras_export('keras.utils.SequenceEnqueuer') class SequenceEnqueuer(object): """Base class to enqueue inputs. @@ -676,7 +672,6 @@ def get(self): raise NotImplementedError -@keras_export('keras.utils.OrderedEnqueuer') class OrderedEnqueuer(SequenceEnqueuer): """Builds a Enqueuer from a Sequence. @@ -809,7 +804,6 @@ def next_sample(uid): return next(_SHARED_SEQUENCES[uid]) -@keras_export('keras.utils.GeneratorEnqueuer') class GeneratorEnqueuer(SequenceEnqueuer): """Builds a queue out of a data generator. diff --git a/tensorflow/python/keras/utils/dataset_creator.py b/tensorflow/python/keras/utils/dataset_creator.py index b80f594d50593c..b318e2dd7fa4a5 100644 --- a/tensorflow/python/keras/utils/dataset_creator.py +++ b/tensorflow/python/keras/utils/dataset_creator.py @@ -17,10 +17,8 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.types import data as data_types -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.utils.experimental.DatasetCreator', v1=[]) class DatasetCreator(object): """Object that returns a `tf.data.Dataset` upon invoking. diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index 4d50c93808e8c6..4b4a321f9bdd02 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -33,7 +33,6 @@ from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator -from tensorflow.python.util.tf_export import keras_export _GLOBAL_CUSTOM_OBJECTS = {} _GLOBAL_CUSTOM_NAMES = {} @@ -47,8 +46,6 @@ _LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' -@keras_export('keras.utils.custom_object_scope', # pylint: disable=g-classes-have-attributes - 'keras.utils.CustomObjectScope') class CustomObjectScope(object): """Exposes custom classes/functions to Keras deserialization internals. 
@@ -89,7 +86,6 @@ def __exit__(self, *args, **kwargs): _GLOBAL_CUSTOM_OBJECTS.update(self.backup) -@keras_export('keras.utils.get_custom_objects') def get_custom_objects(): """Retrieves a live reference to the global dictionary of custom objects. @@ -341,7 +337,6 @@ def serialize_keras_class_and_config( return base_config -@keras_export('keras.utils.register_keras_serializable') def register_keras_serializable(package='Custom', name=None): """Registers an object with the Keras serialization framework. @@ -390,7 +385,6 @@ def decorator(arg): return decorator -@keras_export('keras.utils.get_registered_name') def get_registered_name(obj): """Returns the name registered to an object within the Keras framework. @@ -422,7 +416,6 @@ def skip_failed_serialization(): _SKIP_FAILED_SERIALIZATION = prev -@keras_export('keras.utils.get_registered_object') def get_registered_object(name, custom_objects=None, module_objects=None): """Returns the class associated with `name` if it is registered with Keras. @@ -464,7 +457,6 @@ class CustomMaskWarning(Warning): # pylint: enable=g-bad-exception-name -@keras_export('keras.utils.serialize_keras_object') def serialize_keras_object(instance): """Serialize a Keras object into a JSON-compatible representation. @@ -602,7 +594,6 @@ def class_and_config_for_serialized_keras_object( return (cls, cls_config) -@keras_export('keras.utils.deserialize_keras_object') def deserialize_keras_object(identifier, module_objects=None, custom_objects=None, @@ -811,7 +802,6 @@ def has_arg(fn, name, accept_all=False): return name in arg_spec.args or name in arg_spec.kwonlyargs -@keras_export('keras.utils.Progbar') class Progbar(object): """Displays a progress bar. diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 04989c3ebcba70..76b0b138596d66 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -21,10 +21,8 @@ import numpy as np from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.utils.get_source_inputs') def get_source_inputs(tensor, layer=None, node_index=None): """Returns the list of input tensors necessary to compute `tensor`. diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py index 4439c3c2d913b8..6c6a37dfe95835 100644 --- a/tensorflow/python/keras/utils/losses_utils.py +++ b/tensorflow/python/keras/utils/losses_utils.py @@ -24,10 +24,8 @@ from tensorflow.python.ops import cond from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.losses.Reduction', v1=[]) class ReductionV2(object): """Types of loss reduction. 
@@ -276,7 +274,6 @@ def reduce_weighted_loss(weighted_losses, return loss -@keras_export('keras.__internal__.losses.compute_weighted_loss', v1=[]) def compute_weighted_loss(losses, sample_weight=None, reduction=ReductionV2.SUM_OVER_BATCH_SIZE, diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py index 14214e248bdf69..591a8fe2fd3c19 100644 --- a/tensorflow/python/keras/utils/np_utils.py +++ b/tensorflow/python/keras/utils/np_utils.py @@ -15,10 +15,8 @@ """Numpy-related utilities.""" import numpy as np -from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.utils.to_categorical') def to_categorical(y, num_classes=None, dtype='float32'): """Converts a class vector (integers) to binary class matrix. @@ -78,7 +76,6 @@ def to_categorical(y, num_classes=None, dtype='float32'): return categorical -@keras_export('keras.utils.normalize') def normalize(x, axis=-1, order=2): """Normalizes a Numpy array. diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 91c1aab5cdbada..a5c419d73159f0 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -37,7 +37,6 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export def is_tensor_or_tensor_list(v): @@ -334,7 +333,6 @@ def is_symbolic_tensor(tensor): return False -@keras_export('keras.__internal__.utils.register_symbolic_tensor_type', v1=[]) def register_symbolic_tensor_type(cls): """Allows users to specify types regarded as symbolic `Tensor`s. diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index 10ba3be3ff9785..6026caa355fca9 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -20,7 +20,6 @@ import sys from tensorflow.python.keras.utils.io_utils import path_to_string from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import keras_export try: @@ -63,7 +62,6 @@ def add_edge(dot, src, dst): dot.add_edge(pydot.Edge(src, dst)) -@keras_export('keras.utils.model_to_dot') def model_to_dot(model, show_shapes=False, show_dtype=False, @@ -275,7 +273,6 @@ def format_shape(shape): return dot -@keras_export('keras.utils.plot_model') def plot_model(model, to_file='model.png', show_shapes=False, From 78be6ff484c6f02a5fc8fb03edffcf902ab1d7d7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 10:44:13 -0700 Subject: [PATCH 250/410] Update C++ UnsortedSegmentMin/Max grad tests to avoid potential flakiness. Random inputs to max/min can cause issues with numerical gradient test if inputs are very close. PiperOrigin-RevId: 551576996 --- tensorflow/cc/gradients/math_grad_test.cc | 39 ++++++++++++++++------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index b3d77f29b067c6..b7b0b713861bff 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -1061,40 +1061,55 @@ TEST_F(NaryGradTest, Atan2Grad) { RunTest({x1}, {shape}, {y}, {shape}); } +// Deterministic test value for UnsortedSegmentMin/Max, since the numerical +// gradient can be wrong if the compared inputs are nearly the same (which can +// happen with random inputs). 
+constexpr float kUnsortedSegmentMinMaxTestValue[] = { + 0.5f, 0.7f, 0.2f, 1.0f, 1.5f, 10.5f, -0.7f, 1.2f, + -1.0f, 2.5f, 4.2f, 3.7f, 1.2f, -5.0f, -1.5f}; + TEST_F(NaryGradTest, UnsortedSegmentMaxGrad) { - TensorShape shape({3, 2, 5}); + TensorShape shape({3, 1, 5}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto segment_ids = Const(scope_, {0, 0, 1}); auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2); - TensorShape y_shape({2, 2, 5}); - RunTest({x}, {shape}, {y}, {y_shape}); + Tensor x_init_value = + test::AsTensor(kUnsortedSegmentMinMaxTestValue, shape); + TensorShape y_shape({2, 1, 5}); + RunTest({x}, x_init_value, {y}, {y_shape}); } TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_Int64Ids) { - TensorShape shape({3, 2, 5}); + TensorShape shape({3, 1, 5}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto segment_ids = Const(scope_, {0ll, 0ll, 1ll}); auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/2); - TensorShape y_shape({2, 2, 5}); - RunTest({x}, {shape}, {y}, {y_shape}); + TensorShape y_shape({2, 1, 5}); + Tensor x_init_value = + test::AsTensor(kUnsortedSegmentMinMaxTestValue, shape); + RunTest({x}, x_init_value, {y}, {y_shape}); } TEST_F(NaryGradTest, UnsortedSegmentMaxGrad_NegativeIds) { - TensorShape shape({3, 2, 5}); + TensorShape shape({3, 1, 5}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto segment_ids = Const(scope_, {0, 0, -1}); auto y = UnsortedSegmentMax(scope_, x, segment_ids, /*num_segments=*/1); - TensorShape y_shape({1, 2, 5}); - RunTest({x}, {shape}, {y}, {y_shape}); + TensorShape y_shape({1, 1, 5}); + Tensor x_init_value = + test::AsTensor(kUnsortedSegmentMinMaxTestValue, shape); + RunTest({x}, x_init_value, {y}, {y_shape}); } TEST_F(NaryGradTest, UnsortedSegmentMinGrad) { - TensorShape shape({3, 2, 5}); + TensorShape shape({3, 1, 5}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); auto segment_ids = Const(scope_, {0, 0, 1}); auto y = UnsortedSegmentMin(scope_, x, segment_ids, /*num_segments=*/2); - TensorShape y_shape({2, 2, 5}); - RunTest({x}, {shape}, {y}, {y_shape}); + TensorShape y_shape({2, 1, 5}); + Tensor x_init_value = + test::AsTensor(kUnsortedSegmentMinMaxTestValue, shape); + RunTest({x}, x_init_value, {y}, {y_shape}); } TEST_F(NaryGradTest, UnsortedSegmentSumGrad) { From 2dd1d399947c41b59d4aac9a8ff6a30af34ff64c Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Thu, 27 Jul 2023 10:44:53 -0700 Subject: [PATCH 251/410] Tighten up DVariable constructor. Ban creating non-DTensor variable from DVariable constructor. If initial_value is not a dtensor, and a layout is provided, relayout the initial_value to DTensor. 
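A minimal sketch of the tightened contract in eager mode (illustrative only;
`layout` and `other_layout` stand for layouts on an existing mesh and are not
part of this change):

    from tensorflow.dtensor.python import api
    from tensorflow.dtensor.python import d_variable

    d_variable.DVariable(1.0)                      # raises: neither layout nor DTensor value
    d_variable.DVariable([1., 2.], layout=layout)  # ok: initial value is relayout'ed
    v = api.relayout([1., 2.], layout)
    d_variable.DVariable(v)                        # ok: layout is taken from the DTensor
    d_variable.DVariable(v, layout=other_layout)   # raises: conflicting layouts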
Also open source variable_test.py PiperOrigin-RevId: 551577201 --- tensorflow/dtensor/python/d_variable.py | 41 +++- tensorflow/dtensor/python/tests/BUILD | 30 +++ .../dtensor/python/tests/device_test.py | 6 +- .../dtensor/python/tests/variable_test.py | 226 ++++++++++++++++++ 4 files changed, 287 insertions(+), 16 deletions(-) create mode 100644 tensorflow/dtensor/python/tests/variable_test.py diff --git a/tensorflow/dtensor/python/d_variable.py b/tensorflow/dtensor/python/d_variable.py index 50eb008bd169f9..b99875f03fa70e 100644 --- a/tensorflow/dtensor/python/d_variable.py +++ b/tensorflow/dtensor/python/d_variable.py @@ -14,7 +14,6 @@ # ============================================================================== """DTensor variable and saveable.""" -import contextlib import functools from tensorflow.dtensor.python import api @@ -22,7 +21,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -214,18 +213,34 @@ def __init__(self, initial_value, *args, dtype=None, **kwargs): with ops.device(variable_device): # If initial tensor assigned to DVariable is DTensor, record the layout of # the resource so that this can be queried. - self.layout = None if context.executing_eagerly(): - try: - self.layout = api.fetch_layout(initial_value) - except (errors.InvalidArgumentError, errors.NotFoundError): - # For Non-DTensor tensors, fetch layout results in expected - # InvalidArgument or NotFoundError depending on whether the API - # is called within DTensor device scope or not. - self.layout = None - pass - mesh = self.layout.mesh if self.layout else None - with api.default_mesh(mesh) if mesh else contextlib.nullcontext(): + if api.is_dtensor(initial_value): + value_layout = api.fetch_layout(initial_value) + if layout is not None and layout != value_layout: + raise errors_impl.InvalidArgumentError( + None, + None, + 'Conflicting layout are provided for initial ' + f'value layout ({value_layout}) and variable ({layout}).', + ) + layout = value_layout + elif layout is not None: + initial_value = api.relayout(initial_value, layout) + else: + raise errors_impl.InvalidArgumentError( + None, + None, + 'Neither layout nor DTensor initial value are provided.', + ) + self.layout = layout + with api.default_mesh(layout.mesh): + super(DVariable, self).__init__( + initial_value, *args, dtype=dtype, **kwargs + ) + else: + # FIXME(175928457): Record value layout in graph mode. + if layout is not None: + initial_value = api.relayout(initial_value, layout) super(DVariable, self).__init__( initial_value, *args, dtype=dtype, **kwargs) diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 5b64ad1d4b549b..615baad3085f67 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -599,3 +599,33 @@ dtensor_test( "@absl_py//absl/testing:parameterized", ], ) + +dtensor_test( + name = "variable_test", + srcs = ["variable_test.py"], + disable_tfrt = [ + "tpu", + "gpu", + ], # b/198521331 timeout on TFRT TPU. 
+ main = "variable_test.py", + tags = [ + "nomultivm", + ], + deps = [ + ":test_util", + "//tensorflow/dtensor/python:api", + "//tensorflow/dtensor/python:d_variable", + "//tensorflow/dtensor/python:layout", + "//tensorflow/python/eager/polymorphic_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/dtensor/python/tests/device_test.py b/tensorflow/dtensor/python/tests/device_test.py index 3e8c073e4724f0..e1e5c66e21c341 100644 --- a/tensorflow/dtensor/python/tests/device_test.py +++ b/tensorflow/dtensor/python/tests/device_test.py @@ -197,7 +197,7 @@ def testImplicitCopyVariableOnCPUMesh(self, dtype, shape): self.skipForDeviceType( ["GPU", "TPU"], "Variable implicit copy is only allowed for CPU mesh.") - variable = d_variable.DVariable(array_ops.ones(shape=shape, dtype=dtype)) + variable = variables.Variable(array_ops.ones(shape=shape, dtype=dtype)) new_value = array_ops.zeros(shape=shape, dtype=dtype) @polymorphic_function.function @@ -561,8 +561,8 @@ def testPackingVariablesRaisesError(self): ): api._dtensor_device().pack( [ - d_variable.DVariable(array_ops.ones([2, 3])), - d_variable.DVariable(array_ops.ones([2, 3])), + variables.Variable(array_ops.ones([2, 3])), + variables.Variable(array_ops.ones([2, 3])), ], Layout.replicated(self.mesh, rank=2), ) diff --git a/tensorflow/dtensor/python/tests/variable_test.py b/tensorflow/dtensor/python/tests/variable_test.py new file mode 100644 index 00000000000000..2c3343b609824c --- /dev/null +++ b/tensorflow/dtensor/python/tests/variable_test.py @@ -0,0 +1,226 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DTensor support of Variables.""" +import numpy as np + +from tensorflow.dtensor.python import api +from tensorflow.dtensor.python import d_variable +from tensorflow.dtensor.python import layout as layout_lib +from tensorflow.dtensor.python.tests import test_util +from tensorflow.python.eager.polymorphic_function import polymorphic_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +# Makes a 1D mesh with dimension X(2). 
+_MESH_DIM_X = 'x' +_DEVICE_IDS = test_util.create_device_ids_array((2,)) +_ONE_D_CPU_MESH = layout_lib.Mesh( + [_MESH_DIM_X], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'CPU'), +) +_ONE_D_TPU_MESH = layout_lib.Mesh( + [_MESH_DIM_X], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'TPU'), +) +_ONE_D_GPU_MESH = layout_lib.Mesh( + [_MESH_DIM_X], + _DEVICE_IDS, + np.ravel(_DEVICE_IDS).tolist(), + test_util.create_device_list((2,), 'GPU'), +) + +Layout = layout_lib.Layout +UNSHARDED = layout_lib.UNSHARDED +DVariable = d_variable.DVariable + + +class Var(object): + + def __init__(self): + self.v = None + + +class DTensorVariableTest(test_util.DTensorBaseTest): + + def setUp(self): + super(DTensorVariableTest, self).setUp() + mesh_dict = { + 'CPU': _ONE_D_CPU_MESH, + 'GPU': _ONE_D_GPU_MESH, + 'TPU': _ONE_D_TPU_MESH, + } + self.mesh = self.configTestMesh(mesh_dict) + self._replicated_layout = Layout([UNSHARDED, UNSHARDED], self.mesh) + self._one_d_replicated_layout = Layout([UNSHARDED], self.mesh) + self._scalar_replicated_layout = Layout([], self.mesh) + self._one_d_shard_layout = Layout([_MESH_DIM_X], self.mesh) + self._first_d_shard_layout = Layout([_MESH_DIM_X, UNSHARDED], self.mesh) + + def testNonDtensorVariable(self): + non_dtensor_variable = variables.Variable(1.0) + with ops.device_v2(api.device_name()): + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + 'No default mesh has been registered to DTensor', + ): + non_dtensor_variable.read_value() + + def testDVariableNoLayout(self): + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + 'Neither layout nor DTensor initial value are provided.', + ): + DVariable(1.0) + + def testDVariableConflictingLayout(self): + a = api.relayout([1, 2, 3, 4], self._one_d_replicated_layout) + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, 'Conflicting layout are provided' + ): + DVariable(a, layout=self._one_d_shard_layout) + + def testVariable(self): + with ops.device_v2(api.device_name()): + initial_value = api.relayout([1.0], self._one_d_replicated_layout) + v = variables.Variable(initial_value) + v = api.relayout(v, self._one_d_replicated_layout) + api.check_layout(v, self._one_d_replicated_layout) + + def testVariableWithInitialValue(self): + a = constant_op.constant([1.0]) + a = api.relayout(a, self._one_d_replicated_layout) + with ops.device_v2(api.device_name()): + v = variables.Variable(initial_value=a) + api.check_layout(v, self._one_d_replicated_layout) + to_add = api.relayout([1.0], self._one_d_replicated_layout) + v = v.assign_add(to_add) + api.check_layout(v, self._one_d_replicated_layout) + + def testVarAssignmentOpByOp(self): + v = constant_op.constant(1.0) + v = api.relayout(v, Layout.replicated(self.mesh, rank=0)) + w = d_variable.DVariable(v) + api.check_layout(w, Layout.replicated(self.mesh, rank=0)) + self.assertEqual(w.numpy(), 1.0) + w.assign_add(v) + api.check_layout(w, Layout.replicated(self.mesh, rank=0)) + self.assertEqual(w.numpy(), 2.0) + + def testVarInitOutsideTfFunction(self): + v = constant_op.constant(1.0) + v = api.relayout(v, Layout.replicated(self.mesh, rank=0)) + + w = d_variable.DVariable(v) + + @polymorphic_function.function() + def assign_var(x): + w.assign(x * 2) + return w + x + + out = assign_var(constant_op.constant(1.0)) + api.check_layout(w, Layout.replicated(self.mesh, rank=0)) + self.assertEqual(w.numpy(), 2.0) + self.assertEqual(out.numpy(), 3.0) + + def 
testDVariableInitFromValues(self): + # Python value 1 will be converted as dtypes.int32 without a dtype. + non_dtensor_variable = variables.Variable(1, dtype=dtypes.int64) + with ops.device_v2(api.device_name()): + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + 'No default mesh has been registered to DTensor', + ): + non_dtensor_variable.read_value() + + dtensor_variable = DVariable( + api.relayout( + constant_op.constant(1, dtype=dtypes.int64), + Layout.replicated(self.mesh, rank=0), + ) + ) + + with ops.device_v2(api.device_name()): + dtensor_variable = DVariable( + api.relayout( + constant_op.constant(1, dtype=dtypes.int64), + Layout.replicated(self.mesh, rank=0), + ) + ) + self.assertEqual(dtensor_variable.numpy(), 1) + + def testCreateVarInsideFunctionWithInitScope(self): + var = Var() + + @polymorphic_function.function + def assign_add(): + with ops.init_scope(): + if var.v is None: + c = constant_op.constant(1.0) + c = api.relayout(c, Layout.replicated(self.mesh, rank=0)) + var.v = variables.Variable(c) + var.v.assign_add(1.0) + + with api._dtensor_device()._default_layout( + Layout.replicated(self.mesh, rank=0)): + assign_add() + output = var.v.read_value() + api.check_layout(output, Layout.replicated(self.mesh, rank=0)) + self.assertAllEqual(output, 2.) + + def testBufferAliasingOnDF(self): + # TODO(b/239471086): re-enable once b/239471086 is fixed + self.skipTest('Disabled due to b/239471086') + self.skipForDeviceType(['GPU', 'CPU'], 'Test only applies to DF TPU') + + @polymorphic_function.function + def add_var(v): + new_v = array_ops_stack.stack([v, v]) + v.assign(math_ops.reduce_sum(new_v, axis=0)) + # Note that this only works with DF. When updated to PF, we need to + # adjust tensor size accordingly. + # + # Without aliasing, the returned tensor is 3.5G * 4 = 14G, plus the + # reserved 500MB and 3.5G arguments, this exceeds the 16G memory. + # With aliasing, the 3.5G arguments will be aliased so it lowers the + # memory pressure to 18-3.5 = 14.5G, which barely fits the memory. + return v, new_v + + # Around 3.5G tensor. + v = DVariable( + initial_value=api.relayout( + array_ops.ones((7, 512, 1024, 256), dtype=dtypes.float32), + Layout.replicated(self.mesh, rank=4), + ) + ) + + add_var(v) + + self.assertEqual(api.fetch_layout(v), Layout.replicated(self.mesh, rank=4)) + + +if __name__ == '__main__': + test.main() From 5c1559b3a8fe45c7b5d1c4381cc2e79406936a28 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 10:54:42 -0700 Subject: [PATCH 252/410] Include indirect dependencies check under `DependOn()` in AllReduce combiner using side effect analysis. This enables multi-steps training in a TF function, or finer-grain AllReduce combinations for optimization purposes. 
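Roughly, the pattern this covers (a hedged sketch, not code taken from this
change; `v` is a DTensor-backed variable and `x` a sharded tensor whose
reductions lower to DTensorAllReduce ops):

    from tensorflow.python.eager.polymorphic_function import polymorphic_function
    from tensorflow.python.ops import math_ops

    @polymorphic_function.function
    def two_step(v, x):
      a = math_ops.reduce_sum(x)          # first all-reduce
      v.assign_add(a)                     # side effect: writes the resource v
      return math_ops.reduce_sum(v + x)   # second all-reduce must see the updated v

The second reduction depends on the first only through the resource variable
`v`, so a pure use-def walk in `DependsOn()` treats the two all-reduces as
independent and may combine them, letting the second read a stale `v`.
Following the control successors reported by side effect analysis keeps such
pairs ordered.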
PiperOrigin-RevId: 551580422 --- tensorflow/dtensor/mlir/BUILD | 1 + .../dtensor_allreduce_combine_optimization.cc | 35 +++++++++++++------ .../dtensor/python/tests/collective_test.py | 18 ++-------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index d627c163409b62..2e7f1a66247b47 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -235,6 +235,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index 038df0a00eecd8..ea4f08e88b1c49 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -32,6 +33,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -61,7 +63,8 @@ constexpr int32 kAllReducePadding = 1024; // TODO(jiawenhao): Repeatedly computing dependency sets for a large cluster can // get expensive when the number of all-reduces is high. Consider building a // cluster-scope op dependency graph ahead of time to amortize the cost. -bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) { +bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor, + const mlir::TF::detail::SideEffectAnalysisInfo& info) { llvm::SmallVector to_visit; llvm::SmallPtrSet visited; to_visit.push_back(predecessor); @@ -74,6 +77,11 @@ bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) { if (visited.contains(user)) continue; to_visit.push_back(user); } + // Include indirectly dependent ops from side effects + for (mlir::Operation* user : info.DirectControlSuccessors(producer)) { + if (visited.contains(user)) continue; + to_visit.push_back(user); + } } return false; } @@ -224,16 +232,17 @@ mlir::LogicalResult MergeAllReduceGroup( // Dump the dependencies between AllReduce ops as a DOT graph. 
std::string DrawAllReduceDependencies( - std::vector all_reduces) { + std::vector all_reduces, + const mlir::TF::detail::SideEffectAnalysisInfo& info) { std::vector> dependents(all_reduces.size(), std::vector()); for (int j = 0; j < all_reduces.size(); ++j) { mlir::TF::DTensorAllReduceOp later = all_reduces[j]; for (int i = 0; i < j; ++i) { mlir::TF::DTensorAllReduceOp earlier = all_reduces[i]; - DCHECK(!DependsOn(earlier, later)); + DCHECK(!DependsOn(earlier, later, info)); if (earlier->getBlock() != later->getBlock() || - DependsOn(later, earlier)) { + DependsOn(later, earlier, info)) { dependents[i].push_back(j); } } @@ -357,7 +366,8 @@ bool same_group_assignments(mlir::Value group_assignment_a, std::vector> createIndependentReduceOpsGroups( - const std::vector& ordered_all_reduces) { + const std::vector& ordered_all_reduces, + const mlir::TF::detail::SideEffectAnalysisInfo& info) { // Build a reverse adjacency matrix from node to its dependents. std::vector> dependents(ordered_all_reduces.size(), std::vector()); @@ -366,8 +376,8 @@ createIndependentReduceOpsGroups( mlir::TF::DTensorAllReduceOp requirement = ordered_all_reduces[i]; for (int j = i + 1; j < num_all_reduces; ++j) { mlir::TF::DTensorAllReduceOp dependent = ordered_all_reduces[j]; - DCHECK(!DependsOn(requirement, - dependent)); // guaranteed by program order + DCHECK(!DependsOn(requirement, dependent, + info)); // guaranteed by program order // In this example, all three DTensorAllReduce ops are independent // from each other according to MLIR value use-def chains considered // by DependsOn. However, moving all three to after the WhileRegion @@ -394,7 +404,7 @@ createIndependentReduceOpsGroups( // on" the first one, and the third on the second. This effectively // prevents any two DTensorAllReduce from merging together. if (requirement->getBlock() != dependent->getBlock() || - DependsOn(dependent, requirement)) { + DependsOn(dependent, requirement, info)) { dependents[i].push_back(j); } } @@ -434,7 +444,7 @@ createIndependentReduceOpsGroups( // Export the all reduces as a DOT graph. 
VLOG(4) << "Visualizing AllReduce dependencies:\n" - << DrawAllReduceDependencies(ordered_all_reduces); + << DrawAllReduceDependencies(ordered_all_reduces, info); return all_reduce_groups; } @@ -675,10 +685,15 @@ struct DTensorAllReduceCombineOptimization }); if (ordered_all_reduces.size() > 1) { + // Build side effect analysis to identify indirect dependencies between + // all eligible all_reduce operations + mlir::TF::SideEffectAnalysis side_effect_analysis(module); + const mlir::TF::detail::SideEffectAnalysisInfo& info = + side_effect_analysis.GetAnalysisForFunc(function); // Create dependency graph for all eligible all_reduce operations, // so that independent ops can be merged auto all_reduce_groups = - createIndependentReduceOpsGroups(ordered_all_reduces); + createIndependentReduceOpsGroups(ordered_all_reduces, info); VLOG(2) << ordered_all_reduces.size() << " all-reduce ops in " << all_reduce_groups.size() << " groups"; diff --git a/tensorflow/dtensor/python/tests/collective_test.py b/tensorflow/dtensor/python/tests/collective_test.py index 0e9886f1985447..a6af55a5f2cd18 100644 --- a/tensorflow/dtensor/python/tests/collective_test.py +++ b/tensorflow/dtensor/python/tests/collective_test.py @@ -573,21 +573,9 @@ def func(v): expected_result += expected_result + math_ops.reduce_sum(expected_result) expected_result = math_ops.reduce_sum(expected_result) - # TODO(b/288347987): AllReduce combiner would currently generate the wrong - # output from merging the two ops, where b would not read the updated v from - # addition with a and generate the wrong output. - # - # Mismatched elements: 1 / 1 (100%) - # Max absolute difference: 1056. - # Max relative difference: 33. - # x: array(1088.) # expected output - # y: array(32.) # dtensor output - # - # This test is set to pass by asserting equal on dtensor_result itself. Once - # the bug is fixed, please update the check to expected_result. - self.assertDTensorEqual(dtensor_result, # FIXME: update to expected_result - Layout.replicated(mesh=mesh, rank=0), - dtensor_result) + self.assertDTensorEqual( + expected_result, Layout.replicated(mesh=mesh, rank=0), dtensor_result + ) if __name__ == '__main__': test.main() From a9e383efca4cf62119be163ec8212e86afbf375c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 10:55:30 -0700 Subject: [PATCH 253/410] Add mlir tests for the current FFT SPMD expander(CL 550657034). 
PiperOrigin-RevId: 551580678 --- tensorflow/dtensor/mlir/tests/spmd_fft.mlir | 150 ++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tensorflow/dtensor/mlir/tests/spmd_fft.mlir diff --git a/tensorflow/dtensor/mlir/tests/spmd_fft.mlir b/tensorflow/dtensor/mlir/tests/spmd_fft.mlir new file mode 100644 index 00000000000000..c9f44f21b3f73b --- /dev/null +++ b/tensorflow/dtensor/mlir/tests/spmd_fft.mlir @@ -0,0 +1,150 @@ +// RUN: dtensor-opt %s -split-input-file -dtensor-spmd-expansion -verify-diagnostics | FileCheck %s --dump-input=fail + +// Check the SPMD expansion for FFT2D +// CHECK-LABEL: module @test_FFT2D +module @test_FFT2D { + func.func @main(%arg0: tensor, + %arg1: tensor<2x4x8xcomplex> {tf._layout = "sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) { + // CHECK: "tf_device.cluster" + // CHECK: %[[FFT_OUT_1:.*]] = "tf.FFT"(%arg1) + // CHECK-SAME: _layout = ["sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + // CHECK-NEXT: %[[CONST_OUT_1:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_1:.*]] = "tf.Transpose"(%[[FFT_OUT_1]], %[[CONST_OUT_1]]) + // CHECK-SAME: (tensor<2x4x8xcomplex>, tensor<3xi64>) -> tensor<2x8x4xcomplex> + // CHECK-NEXT: %[[IDENT_OUT:.*]] = "tf.Identity"(%[[TRANS_OUT_1]]) + // CHECK-NEXT: %[[FFT_OUT_2:.*]] = "tf.FFT"(%[[IDENT_OUT]]) + // CHECK-SAME: _layout = ["sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x8x4xcomplex>) -> tensor<2x8x4xcomplex> + // CHECK-NEXT: %[[CONST_OUT_2:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_2:.*]] = "tf.Transpose"(%[[FFT_OUT_2]], %[[CONST_OUT_2]]) + // CHECK-SAME: (tensor<2x8x4xcomplex>, tensor<3xi64>) -> tensor<2x4x8xcomplex> + // CHECK-NEXT: tf_device.return + %0 = "tf_device.cluster"() ({ + %1 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + %2 = "tf.FFT2D"(%1) : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + %3 = "tf.DTensorLayout"(%2) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + tf_device.return %3 : tensor<2x4x8xcomplex> + }) {_mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x4x8xcomplex> + func.return + } +} + +// ----- + +// Check the SPMD expansion for IFFT2D +// CHECK-LABEL: module @test_IFFT2D +module @test_IFFT2D { + func.func @main(%arg0: tensor, + %arg1: tensor<2x4x8xcomplex> {tf._layout = "sharding_specs:b,x,y, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) { + // CHECK: "tf_device.cluster" + // CHECK: %[[CONST_OUT_1:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_1:.*]] = "tf.Transpose"(%arg1, %[[CONST_OUT_1]]) + // CHECK-SAME: (tensor<2x4x4xcomplex>, tensor<3xi64>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[IDENT_OUT:.*]] = 
"tf.Identity"(%[[TRANS_OUT_1]]) + // CHECK-NEXT: %[[IFFT_OUT_1:.*]] = "tf.IFFT"(%[[IDENT_OUT]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x4xcomplex>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[CONST_OUT_2:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_2:.*]] = "tf.Transpose"(%[[IFFT_OUT_1]], %[[CONST_OUT_2]]) + // CHECK-SAME: (tensor<2x4x4xcomplex>, tensor<3xi64>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[ALLTOALL_OUT:.*]] = "tf.DTensorAllToAll"(%[[TRANS_OUT_2]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x4xcomplex>) -> tensor<2x2x8xcomplex> + // CHECK-NEXT: %[[IFFT_OUT_2:.*]] = "tf.IFFT"(%[[ALLTOALL_OUT]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x2x8xcomplex>) -> tensor<2x2x8xcomplex> + // CHECK-NEXT: tf_device.return + %0 = "tf_device.cluster"() ({ + %1 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + %2 = "tf.IFFT2D"(%1) : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + %3 = "tf.DTensorLayout"(%2) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + tf_device.return %3 : tensor<2x4x8xcomplex> + }) {_mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x4x8xcomplex> + func.return + } +} + +// ----- + +// Check the SPMD expansion for RFFT2D +// CHECK-LABEL: module @test_RFFT2D +module @test_RFFT2D { + func.func @main(%arg0: tensor, + %arg1: tensor<2x4x12xf64> {tf._layout = "sharding_specs:b,x,y, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}, + %arg2: tensor<2xi32> {tf._layout = "sharding_specs:unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) { + // CHECK: "tf_device.cluster" + // CHECK: %[[CONST_OUT_1:.*]] = "tf.Const"() + // CHECK-SAME: _layout = ["sharding_specs:unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-NEXT: %[[ALLGATHER_OUT:.*]] = "tf.DTensorAllGather"(%arg1) + // CHECK-SAME: _layout = ["sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK: %[[CONST_OUT_2:.*]] = "tf.Const"() + // CHECK: %[[RFFT_OUT:.*]] = "tf.RFFT"(%[[ALLGATHER_OUT]], %[[CONST_OUT_2]]) + // CHECK-SAME: _layout = ["sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x12xf64>, tensor<1xi32>) -> tensor<2x4x6xcomplex> + // CHECK-NEXT: %[[CONST_OUT_3:.*]] = "tf.Const"() + 
// CHECK-NEXT: %[[TRANS_OUT_1:.*]] = "tf.Transpose"(%[[RFFT_OUT]], %[[CONST_OUT_3]]) + // CHECK-SAME: (tensor<2x4x6xcomplex>, tensor<3xi64>) -> tensor<2x6x4xcomplex> + // CHECK-NEXT: %[[IDENT_OUT:.*]] = "tf.Identity"(%[[TRANS_OUT_1]]) + // CHECK-NEXT: %[[FFT_OUT:.*]] = "tf.FFT"(%[[IDENT_OUT]]) + // CHECK-SAME: _layout = ["sharding_specs:b,x,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x6x4xcomplex>) -> tensor<2x6x4xcomplex> + // CHECK-NEXT: %[[CONST_OUT_4:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_2:.*]] = "tf.Transpose"(%[[FFT_OUT]], %[[CONST_OUT_4]]) + // CHECK-SAME: _layout = ["sharding_specs:b,unsharded,x, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x6x4xcomplex>, tensor<3xi64>) -> tensor<2x4x6xcomplex> + // CHECK-NEXT: tf_device.return + %0 = "tf_device.cluster"() ({ + %cst = "tf.Const"() {value = dense<[4, 10]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.DTensorLayout"(%arg2) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<2x4x12>, layout = #dtensor.layout} : (tensor<2x4x12xf64>) -> tensor<2x4x12xf64> + %3 = "tf.DTensorLayout"(%cst) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> + %4 = "tf.RFFT2D"(%2, %3) : (tensor<2x4x12xf64>, tensor<2xi32>) -> tensor<2x4x6xcomplex> + %5 = "tf.DTensorLayout"(%4) {global_shape = #tf_type.shape<2x4x6>, layout = #dtensor.layout} : (tensor<2x4x6xcomplex>) -> tensor<2x4x6xcomplex> + tf_device.return %5 : tensor<2x4x6xcomplex> + }) {_mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x4x6xcomplex> + func.return + } +} + + +// ----- + +// Check the SPMD expansion for IRFFT2D +// CHECK-LABEL: module @test_IRFFT2D +module @test_IRFFT2D { + func.func @main(%arg0: tensor, + %arg1: tensor<2x4x8xcomplex> {tf._layout = "sharding_specs:b,unsharded,y, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}, + %arg2: tensor<2xi32> {tf._layout = "sharding_specs:unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1", tf._mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"}) { + // CHECK: "tf_device.cluster" + // CHECK: %[[CONST_OUT_1:.*]] = "tf.Const"() + // CHECK-SAME: _layout = ["sharding_specs:unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-NEXT: %[[CONST_OUT_2:.*]] = "tf.Const"() + // CHECK-NEXT: %[[TRANS_OUT_1:.*]] = "tf.Transpose"(%arg1, %[[CONST_OUT_2]]) + // CHECK-SAME: (tensor<2x4x4xcomplex>, tensor<3xi64>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[IFFT_OUT:.*]] = "tf.IFFT"(%[[TRANS_OUT_1]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x4xcomplex>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[CONST_OUT_3:.*]] = "tf.Const"() + 
// CHECK-NEXT: %[[TRANS_OUT_2:.*]] = "tf.Transpose"(%[[IFFT_OUT]], %[[CONST_OUT_3]]) + // CHECK-SAME: (tensor<2x4x4xcomplex>, tensor<3xi64>) -> tensor<2x4x4xcomplex> + // CHECK-NEXT: %[[ALLTOALL_OUT:.*]] = "tf.DTensorAllToAll"(%[[TRANS_OUT_2]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x4x4xcomplex>) -> tensor<2x2x8xcomplex> + // CHECK-NEXT: %[[CONST_OUT_4:.*]] = "tf.Const"() + // CHECK-NEXT: %[[IRFFT_OUT:.*]] = "tf.IRFFT"(%[[ALLTOALL_OUT]], %[[CONST_OUT_4]]) + // CHECK-SAME: _layout = ["sharding_specs:b,y,unsharded, mesh:|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"] + // CHECK-SAME: (tensor<2x2x8xcomplex>, tensor<1xi32>) -> tensor<2x2x8xf64> + // CHECK-NEXT: tf_device.return + %0 = "tf_device.cluster"() ({ + %cst = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.DTensorLayout"(%arg2) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.DTensorLayout"(%arg1) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xcomplex>) -> tensor<2x4x8xcomplex> + %3 = "tf.DTensorLayout"(%cst) {global_shape = #tf_type.shape<2>, layout = #dtensor.layout} : (tensor<2xi32>) -> tensor<2xi32> + %4 = "tf.IRFFT2D"(%2, %3) : (tensor<2x4x8xcomplex>, tensor<2xi32>) -> tensor<2x4x8xf64> + %5 = "tf.DTensorLayout"(%4) {global_shape = #tf_type.shape<2x4x8>, layout = #dtensor.layout} : (tensor<2x4x8xf64>) -> tensor<2x4x8xf64> + tf_device.return %5 : tensor<2x4x8xf64> + }) {_mesh = "|b=1,x=1,y=2|0,1|0,1|/job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1"} : () -> tensor<2x4x8xf64> + func.return + } +} \ No newline at end of file From 94f293389cbf6b33c8acd95727398f37bf8b7401 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Thu, 27 Jul 2023 10:56:43 -0700 Subject: [PATCH 254/410] Fix paths and convert continuous envs from cache to RBE PiperOrigin-RevId: 551581062 --- ci/official/envs/continuous_linux_x86_cpu_py310 | 4 ++-- ci/official/envs/continuous_linux_x86_cpu_py311 | 4 ++-- ci/official/envs/continuous_linux_x86_cpu_py39 | 4 ++-- ci/official/envs/continuous_linux_x86_cuda_py310 | 2 +- ci/official/envs/continuous_linux_x86_cuda_py311 | 2 +- ci/official/envs/continuous_linux_x86_cuda_py39 | 2 +- ci/official/envs/nightly_libtensorflow_linux_x86_cpu | 2 +- ci/official/envs/nightly_linux_x86_cpu_py310 | 2 +- ci/official/envs/nightly_linux_x86_cpu_py311 | 2 +- ci/official/envs/nightly_linux_x86_cpu_py39 | 2 +- ci/official/utilities/rename_and_verify_wheels.sh | 6 +++--- 11 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ci/official/envs/continuous_linux_x86_cpu_py310 b/ci/official/envs/continuous_linux_x86_cpu_py310 index 46997fb367d3db..7e3370b1ce5829 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py310 +++ b/ci/official/envs/continuous_linux_x86_cpu_py310 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) TFCI_COPYBARA_ENABLE=0 
TFCI_DOCKER_ENABLE=1 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI= diff --git a/ci/official/envs/continuous_linux_x86_cpu_py311 b/ci/official/envs/continuous_linux_x86_cpu_py311 index 1e3b7df5ea6857..6b057815a6f24a 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py311 +++ b/ci/official/envs/continuous_linux_x86_cpu_py311 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI= diff --git a/ci/official/envs/continuous_linux_x86_cpu_py39 b/ci/official/envs/continuous_linux_x86_cpu_py39 index fc7ca80562235d..f640481fc9ce83 100644 --- a/ci/official/envs/continuous_linux_x86_cpu_py39 +++ b/ci/official/envs/continuous_linux_x86_cpu_py39 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cpu.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=(--cpu) TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI= diff --git a/ci/official/envs/continuous_linux_x86_cuda_py310 b/ci/official/envs/continuous_linux_x86_cuda_py310 index 5c6cf6f8397867..102196c2c39bff 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py310 +++ b/ci/official/envs/continuous_linux_x86_cuda_py310 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=() TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py311 b/ci/official/envs/continuous_linux_x86_cuda_py311 index 039c1634c6c23c..3b499c1bc91e39 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py311 +++ b/ci/official/envs/continuous_linux_x86_cuda_py311 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" 
#TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=() TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/continuous_linux_x86_cuda_py39 b/ci/official/envs/continuous_linux_x86_cuda_py39 index 1eae7b537a0598..2ebc027606f37f 100644 --- a/ci/official/envs/continuous_linux_x86_cuda_py39 +++ b/ci/official/envs/continuous_linux_x86_cuda_py39 @@ -1,7 +1,7 @@ #TFCI_UPLOAD_LIB_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" #TFCI_UPLOAD_WHL_GCS_URI="gs://tensorflow-release-packages/$RELEASE_VERSION/$KOKORO_GIT_COMMIT_tensorflow" TFCI_BAZEL_BAZELRC_ARGS=(--bazelrc ./ci/official/bazelrcs/cuda.bazelrc) -TFCI_BAZEL_COMMON_ARGS=(--config sigbuild_remote_cache_push) +TFCI_BAZEL_COMMON_ARGS=(--config rbe) TFCI_BUILD_PIP_PACKAGE_ARGS=() TFCI_COPYBARA_ENABLE=0 TFCI_DOCKER_ENABLE=1 diff --git a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu index 0c4b25904482dd..f67cb34a46a986 100644 --- a/ci/official/envs/nightly_libtensorflow_linux_x86_cpu +++ b/ci/official/envs/nightly_libtensorflow_linux_x86_cpu @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" diff --git a/ci/official/envs/nightly_linux_x86_cpu_py310 b/ci/official/envs/nightly_linux_x86_cpu_py310 index ab41bbce158334..8755f38cf508a6 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py310 +++ b/ci/official/envs/nightly_linux_x86_cpu_py310 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" diff --git a/ci/official/envs/nightly_linux_x86_cpu_py311 b/ci/official/envs/nightly_linux_x86_cpu_py311 index 83c6db47e3ac3a..26a22054856bfd 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py311 +++ b/ci/official/envs/nightly_linux_x86_cpu_py311 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" diff --git a/ci/official/envs/nightly_linux_x86_cpu_py39 b/ci/official/envs/nightly_linux_x86_cpu_py39 index 51f87676f61c56..1bf91c8f69448b 100644 --- a/ci/official/envs/nightly_linux_x86_cpu_py39 +++ b/ci/official/envs/nightly_linux_x86_cpu_py39 @@ -12,7 +12,7 @@ TFCI_GIT_DIR=$KOKORO_ARTIFACTS_DIR/github/tensorflow TFCI_INDEX_HTML_ENABLE=1 TFCI_LIB_SUFFIX="-cpu-linux-x86_64" TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 -TFCI_NVIDIA_SMI_ENABLE=1 +TFCI_NVIDIA_SMI_ENABLE= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= TFCI_UPLOAD_LIB_LATEST_URI="gs://libtensorflow-nightly/latest" diff --git a/ci/official/utilities/rename_and_verify_wheels.sh 
b/ci/official/utilities/rename_and_verify_wheels.sh index 8b499afd7c4438..72b26ea56d131c 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -19,13 +19,13 @@ # "manylinux_xyz" into the wheel filename. set -euxo pipefail -cd $1 -for wheel in *.whl; do +DIR=$1 +for wheel in $DIR/*.whl; do echo "Checking and renaming $wheel..." time python3 -m auditwheel repair --plat manylinux2014_x86_64 "$wheel" --wheel-dir build 2>&1 | tee check.txt # We don't need the original wheel if it was renamed - new_wheel=$(grep --extended-regexp --only-matching '\S+.whl' check.txt) + new_wheel=$(grep --extended-regexp --only-matching '\S+.whl' check.txt | tail -n 1) if [[ "$new_wheel" != "$wheel" ]]; then rm "$wheel" wheel="$new_wheel" From c5eb50d9388d4d49ef80c424a1c7e394e00b22cf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 11:40:53 -0700 Subject: [PATCH 255/410] Add C++ gradient for BatchMatMulV3 PiperOrigin-RevId: 551595101 --- tensorflow/cc/gradients/math_grad.cc | 53 +++++++++--- tensorflow/cc/gradients/math_grad_test.cc | 99 ++++++++++++++++------- 2 files changed, 112 insertions(+), 40 deletions(-) diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index ef74aee967f943..625fb98852c020 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -1034,8 +1034,9 @@ REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, const Output& x0, const bool adj_x0, const Output& x1, - const bool adj_x1, const Output& y0, const bool adj_y0, - const Output& y1, const bool adj_y1, + const bool adj_x1, const DataType x_data_type, + const Output& y0, const bool adj_y0, const Output& y1, + const bool adj_y1, const DataType y_data_type, std::vector* grad_outputs) { if (is_batch == false) { auto dx = @@ -1045,11 +1046,11 @@ Status MatMulGradHelper(const Scope& scope, const bool is_batch, MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1)); grad_outputs->push_back(dy); } else { - auto dx = - BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1)); + auto dx = BatchMatMulV3(scope, x0, x1, x_data_type, + BatchMatMulV3::AdjX(adj_x0).AdjY(adj_x1)); grad_outputs->push_back(dx); - auto dy = - BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1)); + auto dy = BatchMatMulV3(scope, y0, y1, y_data_type, + BatchMatMulV3::AdjX(adj_y0).AdjY(adj_y1)); grad_outputs->push_back(dy); } return scope.status(); @@ -1078,17 +1079,21 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { - return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a, - true, grad_inputs[0], false, grad_outputs); + return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, + a.type(), a, true, grad_inputs[0], false, b.type(), + grad_outputs); } else if (!ta && tb) { return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false, - grad_inputs[0], true, a, false, grad_outputs); + a.type(), grad_inputs[0], true, a, false, b.type(), + grad_outputs); } else if (ta && !tb) { - return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a, - false, grad_inputs[0], false, grad_outputs); + return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, + a.type(), a, false, 
grad_inputs[0], false, b.type(), + grad_outputs); } return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true, - grad_inputs[0], true, a, true, grad_outputs); + a.type(), grad_inputs[0], true, a, true, b.type(), + grad_outputs); } Status MatMulGrad(const Scope& scope, const Operation& op, @@ -1107,6 +1112,30 @@ Status BatchMatMulGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad); +Status BatchMatMulV2Grad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + TF_RETURN_IF_ERROR(MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", + "adj_y", grad_outputs)); + + // Reduce along the broadcasted batch dimensions. + Output sx = Shape(scope, op.input(0)); + Output sy = Shape(scope, op.input(1)); + + Output x_batch_shape = Slice(scope, sx, {0}, Sub(scope, Shape(scope, sx), 2)); + Output y_batch_shape = Slice(scope, sy, {0}, Sub(scope, Shape(scope, sy), 2)); + + auto reduce = + internal::BroadcastGradientArgs(scope, x_batch_shape, y_batch_shape); + (*grad_outputs)[0] = + Reshape(scope, ReduceSum(scope, (*grad_outputs)[0], reduce.r0), sx); + (*grad_outputs)[1] = + Reshape(scope, ReduceSum(scope, (*grad_outputs)[1], reduce.r1), sy); + return scope.status(); +} +REGISTER_GRADIENT_OP("BatchMatMulV2", BatchMatMulV2Grad); +REGISTER_GRADIENT_OP("BatchMatMulV3", BatchMatMulV2Grad); + Status CumsumGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index b7b0b713861bff..3026c90babdeec 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -33,6 +36,7 @@ using ops::AddN; using ops::AddV2; using ops::Atan2; using ops::BatchMatMul; +using ops::BatchMatMulV3; using ops::Cast; using ops::Const; using ops::Cumsum; @@ -603,11 +607,42 @@ class MathGradTest : public ::testing::Test { MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} template - void TestMatMulGrad(const bool is_batch, const bool t_x, const bool t_y) { + void TestMatMulGrad(const bool t_x, const bool t_y) { + TestMatMulGradHelper( + /*is_x_batch=*/false, /*is_y_batch=*/false, t_x, t_y, + [&](Output x, Output y) { + return MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); + }); + } + + template + void TestBatchMatMulGrad(const bool t_x, const bool t_y) { + TestMatMulGradHelper( + /*is_x_batch=*/true, /*is_y_batch=*/true, t_x, t_y, + [&](Output x, Output y) { + return BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); + }); + } + + template + void TestBatchMatMulV3Grad(const bool is_x_batch, const bool is_y_batch, + const bool t_x, const bool t_y) { + TestMatMulGradHelper( + /*is_x_batch=*/true, /*is_y_batch=*/true, t_x, t_y, + [&](Output x, Output y) { + return BatchMatMulV3(root_, x, y, DataTypeToEnum::v(), + BatchMatMulV3::AdjX(t_x).AdjY(t_y)); + }); + } + + template + void TestMatMulGradHelper(const bool is_x_batch, const bool is_y_batch, + const bool t_x, const bool t_y, + std::function mul_fn) { TF_ASSERT_OK(root_.status()); // Generate random (but compatible) shapes for matrix multiplication. std::vector shapes; - RandMatMulShapes(is_batch, t_x, t_y, &shapes); + RandMatMulShapes(is_x_batch, is_y_batch, t_x, t_y, &shapes); TensorShape x_shape = shapes[0]; TensorShape y_shape = shapes[1]; TensorShape z_shape = shapes[2]; @@ -615,12 +650,7 @@ class MathGradTest : public ::testing::Test { Placeholder(root_, DataTypeToEnum::v(), Placeholder::Shape(x_shape)); auto y = Placeholder(root_, DataTypeToEnum::v(), Placeholder::Shape(y_shape)); - Output z; - if (is_batch) { - z = BatchMatMul(root_, x, y, BatchMatMul::AdjX(t_x).AdjY(t_y)); - } else { - z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y)); - } + Output z = mul_fn(x, y); float max_error; TF_ASSERT_OK((ComputeGradientError( @@ -628,7 +658,8 @@ class MathGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - void RandMatMulShapes(const bool is_batch, const bool tx, const bool ty, + void RandMatMulShapes(const bool is_x_batch, const bool is_y_batch, + const bool tx, const bool ty, std::vector* shapes) { // Choose a random batch size in [1, 4] const int b = 1 + (random::New64() % 4); @@ -638,7 +669,7 @@ class MathGradTest : public ::testing::Test { const int n = Rand(); TensorShape x_shape; - if (is_batch) { + if (is_x_batch) { // x.shape = [b, m, k] x_shape = tx ? 
TensorShape({b, k, m}) : TensorShape({b, m, k}); } else { @@ -648,7 +679,7 @@ class MathGradTest : public ::testing::Test { shapes->push_back(x_shape); TensorShape y_shape; - if (is_batch) { + if (is_y_batch) { // y.shape = [b, k, n] y_shape = ty ? TensorShape({b, n, k}) : TensorShape({b, k, n}); } else { @@ -658,7 +689,7 @@ class MathGradTest : public ::testing::Test { shapes->push_back(y_shape); TensorShape z_shape; - if (is_batch) { + if (is_x_batch || is_y_batch) { // z.shape = [b, m, n] z_shape = TensorShape({b, m, n}); } else { @@ -674,67 +705,79 @@ class MathGradTest : public ::testing::Test { }; TEST_F(MathGradTest, MatMulGrad_NoTranspose) { - TestMatMulGrad(false, false, false); + TestMatMulGrad(false, false); } TEST_F(MathGradTest, MatMulComplexGrad_NoTranspose) { - TestMatMulGrad(false, false, false); + TestMatMulGrad(false, false); } TEST_F(MathGradTest, MatMulGrad_TransposeX) { - TestMatMulGrad(false, true, false); + TestMatMulGrad(true, false); } TEST_F(MathGradTest, MatMulComplexGrad_TransposeX) { - TestMatMulGrad(false, true, false); + TestMatMulGrad(true, false); } TEST_F(MathGradTest, MatMulGrad_TransposeY) { - TestMatMulGrad(false, false, true); + TestMatMulGrad(false, true); } TEST_F(MathGradTest, MatMulComplexGrad_TransposeY) { - TestMatMulGrad(false, false, true); + TestMatMulGrad(false, true); } TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) { - TestMatMulGrad(false, true, true); + TestMatMulGrad(true, true); } TEST_F(MathGradTest, MatMulComplexGrad_TransposeX_TransposeY) { - TestMatMulGrad(false, true, true); + TestMatMulGrad(true, true); } TEST_F(MathGradTest, BatchMatMulGrad_NoTranspose) { - TestMatMulGrad(true, false, false); + TestBatchMatMulGrad(false, false); } TEST_F(MathGradTest, BatchMatMulComplexGrad_NoTranspose) { - TestMatMulGrad(true, false, false); + TestBatchMatMulGrad(false, false); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeX) { - TestMatMulGrad(true, true, false); + TestBatchMatMulGrad(true, false); } TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX) { - TestMatMulGrad(true, true, false); + TestBatchMatMulGrad(true, false); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeY) { - TestMatMulGrad(true, false, true); + TestBatchMatMulGrad(false, true); } TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeY) { - TestMatMulGrad(true, false, true); + TestBatchMatMulGrad(false, true); } TEST_F(MathGradTest, BatchMatMulGrad_TransposeX_TransposeY) { - TestMatMulGrad(true, true, true); + TestBatchMatMulGrad(true, true); } TEST_F(MathGradTest, BatchMatMulComplexGrad_TransposeX_TransposeY) { - TestMatMulGrad(true, true, true); + TestBatchMatMulGrad(true, true); +} + +TEST_F(MathGradTest, BatchMatMulV3Grad_BroadcastX) { + TestBatchMatMulV3Grad(false, true, false, false); +} + +TEST_F(MathGradTest, BatchMatMulV3Grad_BroadcastY) { + TestBatchMatMulV3Grad(true, false, false, false); +} + +TEST_F(MathGradTest, BatchMatMulV3Grad_BroadcastYTransposeY) { + TestBatchMatMulV3Grad(true, false, false, true); } class NaryGradTest : public ::testing::Test { From 98446be4eb09e2307e494b3b8e9b160e4d2b4516 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Jul 2023 11:51:16 -0700 Subject: [PATCH 256/410] [XLA:Python] Fix build failure on Mac OS. 
PiperOrigin-RevId: 551598083 --- tensorflow/compiler/xla/python/callback.cc | 9 ++++++++- tensorflow/compiler/xla/python/callback.h | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/python/callback.cc b/tensorflow/compiler/xla/python/callback.cc index decee881fca7f0..930096259869c3 100644 --- a/tensorflow/compiler/xla/python/callback.cc +++ b/tensorflow/compiler/xla/python/callback.cc @@ -15,12 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/python/callback.h" +#include + #include #include #include #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/python/exceptions.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" @@ -41,7 +44,11 @@ Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) { if (args_[i].type == xla::TOKEN) { args[i] = py::none(); } else { - args[i] = py::array(args_[i].dtype, args_[i].dims, args_[i].strides, + static_assert(sizeof(ssize_t) == sizeof(int64_t)); + absl::Span strides( + reinterpret_cast(args_[i].strides.data()), + args_[i].strides.size()); + args[i] = py::array(args_[i].dtype, args_[i].dims, strides, const_cast(inputs[i])); args[i].attr("flags").attr("writeable") = Py_False; } diff --git a/tensorflow/compiler/xla/python/callback.h b/tensorflow/compiler/xla/python/callback.h index bf032352c44d86..9325962fd8e9fb 100644 --- a/tensorflow/compiler/xla/python/callback.h +++ b/tensorflow/compiler/xla/python/callback.h @@ -35,7 +35,7 @@ class CpuCallback { xla::PrimitiveType type; // XLA type pybind11::dtype dtype; // NumPy type, for array types. absl::InlinedVector dims; // Dimensions, for array types. - std::vector strides; // Byte strides, for array types. + std::vector strides; // Byte strides, for array types. size_t size_in_bytes; // Size of the array in bytes. }; struct Result { From a693975b7fe515287a0a8064e5ca4debab923bd5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 11:51:50 -0700 Subject: [PATCH 257/410] Unique stacks and frames when attaching debug information to graphs. 
PiperOrigin-RevId: 551598230 --- ..._1.0_224_frozen.wrong_attr.line.part.pbtxt | 2 +- ...1.0_224_frozen.wrong_attr.stack.part.pbtxt | 8 +- .../mlir/tensorflow/translate/import_model.cc | 45 ++- .../mlir/tensorflow/utils/import_utils.cc | 2 +- .../core/common_runtime/function_def_utils.cc | 5 +- .../core/common_runtime/function_test.cc | 126 --------- .../core/common_runtime/graph_constructor.cc | 76 ++---- .../core/common_runtime/graph_constructor.h | 7 +- tensorflow/core/framework/function.cc | 100 ------- tensorflow/core/framework/function.h | 77 +----- .../core/framework/graph_debug_info.proto | 28 +- tensorflow/core/graph/graph.cc | 2 +- .../core/graph/graph_debug_info_builder.cc | 256 ++++++++++++++++-- .../core/graph/graph_debug_info_builder.h | 127 ++++++++- .../graph/graph_debug_info_builder_test.cc | 133 ++++++++- tensorflow/core/graph/graph_partition.cc | 2 +- tensorflow/core/graph/graph_partition_test.cc | 32 +-- tensorflow/core/graph/validate_test.cc | 1 - .../core/ir/importexport/graphdef_import.cc | 25 +- tensorflow/lite/python/authoring/authoring.py | 5 +- tensorflow/lite/python/lite_test.py | 2 +- tensorflow/lite/python/util.py | 15 +- .../python/eager/polymorphic_function/BUILD | 1 + .../polymorphic_function/atomic_function.py | 24 +- .../polymorphic_function_test.py | 2 +- tensorflow/python/framework/BUILD | 1 + .../python/framework/error_interpolation.py | 77 +----- .../framework/error_interpolation_test.py | 160 ----------- .../python/framework/meta_graph_test.py | 24 -- tensorflow/python/saved_model/BUILD | 1 + tensorflow/python/saved_model/save.py | 10 +- tensorflow/python/saved_model/save_test.py | 6 +- tensorflow/python/util/BUILD | 5 +- tensorflow/python/util/tf_stack.cc | 33 ++- tensorflow/python/util/tf_stack.py | 18 ++ tensorflow/python/util/tf_stack_test.py | 19 ++ tensorflow/tsl/platform/stack_frame.h | 6 + 37 files changed, 699 insertions(+), 764 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt index 6b01e99da2a4f1..99794445a91d0a 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt @@ -1,6 +1,6 @@ # RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s -# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format +# CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format node { name: "input" diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt index 9a676c196ce3cc..10df30c5a304ce 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt @@ -1,9 +1,9 @@ # RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 
-tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s -# CHECK: fake/user/code/file_C.py:27:0: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format -# CHECK: fake/user/code/file_D.py:28:0: note: called from -# CHECK: fake/user/code/file_E.py:29:0: note: called from -# CHECK: fake/user/code/file_F.py:30:0: note: called from +# CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format +# CHECK: fake/user/code/file_D.py:28:1: note: called from +# CHECK: fake/user/code/file_E.py:29:1: note: called from +# CHECK: fake/user/code/file_F.py:30:1: note: called from node { name: "input" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index a1c2f06ad49b60..3be4200f6b9543 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -112,6 +112,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" @@ -240,6 +241,8 @@ class ImporterBase { LOG(INFO) << "\t" << it.first << " -> " << it.second; } } + + stack_traces_ = LoadTracesFromDebugInfo(debug_info_); } // Returns the inferred function signature of the given function body. Input @@ -465,6 +468,7 @@ class ImporterBase { const FunctionLibraryDefinition& graph_flib_; const GraphImportConfig& specs_; const GraphDebugInfo& debug_info_; + StackTracesMap stack_traces_; llvm::StringRef function_name_for_debug_info_; NodeValueMap node_values_; // TODO(jpienaar): Remove once shape inference on import is removed. @@ -1740,8 +1744,6 @@ Status ImporterBase::ConvertFunctionArgAndRets( mlir::Location ImporterBase::GetLocation(const Node& node) { DVLOG(1) << "Getting location for " << node.name() << " " << &node; // TODO(b/142400497): What is the semantic contract for locations? - const auto& debug_info = debug_info_.traces(); - // Create a location for node `name` in function `function_name`. auto create_location = [&](llvm::StringRef name, llvm::StringRef function_name) -> mlir::Location { @@ -1759,36 +1761,31 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { function_name.empty() ? name.str() : debug_info_key; auto name_loc_id = mlir::StringAttr::get(context_, name_for_name_loc); - llvm::SmallVector locations; + std::shared_ptr stack_trace = node.GetStackTrace(); + // Prefer stack traces if available, fallback to debug info if not, and then - // finally to just name. - if (auto stack_trace = node.GetStackTrace()) { + // finally to just name. Older versions of debug info concatenated `@` onto + // the node name for the default graph, so we check both locations. 
+ if (stack_trace != nullptr) { + } else if (stack_traces_.contains(name_for_name_loc)) { + stack_trace = stack_traces_.at(name_for_name_loc); + } else if (stack_traces_.contains(debug_info_key)) { + stack_trace = stack_traces_.at(debug_info_key); + } else { + DVLOG(1) << "No stack trace for " << node.name(); + } + + llvm::SmallVector locations; + + if (stack_trace != nullptr) { DVLOG(1) << "Stack available for " << node.name(); - absl::Span frames = stack_trace->ToFrames(); - locations.reserve(frames.size()); - for (const StackFrame& frame : llvm::reverse(frames)) { + for (const StackFrame& frame : stack_trace->ToFrames()) { auto file_name = mlir::StringAttr::get(context_, frame.file_name); // Use col 1 as there is no column info in StackTrace. auto file_line_loc = mlir::FileLineColLoc::get(file_name, frame.line_number, 1); locations.push_back(file_line_loc); } - } else { - DVLOG(1) << "No stack trace for " << node.name(); - const auto location_it = debug_info.find(debug_info_key); - if (location_it != debug_info.end()) { - DVLOG(1) << "Available serialized debug info for " << node.name(); - // Convert the stack trace to a chain of mlir::CallSiteLocs. - const auto& trace = location_it->second; - locations.reserve(trace.file_line_cols_size()); - for (const auto& location : trace.file_line_cols()) { - const auto& file = debug_info_.files(location.file_index()); - auto file_name = mlir::StringAttr::get(context_, file); - auto file_line_loc = mlir::FileLineColLoc::get( - file_name, location.line(), location.col()); - locations.push_back(file_line_loc); - } - } } // If there are no locations in the stack trace, fall back to just a diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 7b3312a76a6179..f80cdf14ae7e79 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -47,7 +47,7 @@ Status LoadProtoFromBuffer(absl::string_view input, protobuf::io::ArrayInputStream binary_stream(input.data(), input.size()); if (proto->ParseFromZeroCopyStream(&binary_stream)) return OkStatus(); - LOG(ERROR) << "Error parsing Protobuf"; + LOG(ERROR) << "Error parsing Protobuf: " << proto->GetTypeName(); return errors::InvalidArgument("Could not parse input proto"); } diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc index ed1582c05a8066..b39d2ba54c24fc 100644 --- a/tensorflow/core/common_runtime/function_def_utils.cc +++ b/tensorflow/core/common_runtime/function_def_utils.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" namespace tensorflow { @@ -50,7 +51,9 @@ Status FunctionDefToBodyHelper( GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); + + TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get(), + /*debug_info=*/nullptr)); const StackTracesMap* stack_traces = lib_def->GetStackTraces(fdef.signature().name()); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 31e8dcc553b641..2369a7e0358e03 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -2452,131 +2452,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithControlDeps) { TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func)); } -class TestStackTrace : public AbstractStackTrace { - public: - explicit TestStackTrace(const std::vector& frames) - : frames_(frames) {} - - absl::Span ToFrames() const override { return frames_; } - - StackFrame LastUserFrame() const override { return frames_.back(); } - - std::vector GetUserFrames(int limit) const override { - return frames_; - } - - string ToString(const TracePrintingOptions& opts) const override { - return ""; - } - - std::vector frames_; -}; - -TEST(StackTracesMapToGraphDebugInfoTest, EmptyMap) { - StackTracesMap map; - GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); - - EXPECT_EQ(generated.files_size(), 0); - EXPECT_EQ(generated.traces_size(), 0); -} - -TEST(StackTracesMapToGraphDebugInfoTest, EmptyFrames) { - StackTracesMap map; - std::vector frames; - auto stack_trace = std::make_shared(frames); - map.insert({"dummy_name", stack_trace}); - GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); - - EXPECT_EQ(generated.files_size(), 0); - EXPECT_EQ(generated.traces_size(), 1); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols().size(), 0); -} - -TEST(StackTracesMapToGraphDebugInfoTest, OneFrame) { - StackTracesMap map; - std::vector frames = { - StackFrame({"dummy_file_name", 10, "dummy_function_name"})}; - auto stack_trace = std::make_shared(frames); - map.insert({"dummy_name", stack_trace}); - GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); - - EXPECT_EQ(generated.files_size(), 1); - EXPECT_EQ(generated.files()[0], "dummy_file_name"); - - EXPECT_EQ(generated.traces_size(), 1); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols().size(), 1); - EXPECT_EQ( - generated.traces().at("dummy_name").file_line_cols()[0].file_index(), 0); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].line(), 10); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].func(), - "dummy_function_name"); -} - -TEST(StackTracesMapToGraphDebugInfoTest, TwoFramesSameFile) { - StackTracesMap map; - std::vector frames = { - StackFrame({"dummy_file_name", 10, "dummy_function_name"}), - StackFrame({"dummy_file_name", 20, "other_function_name"})}; - auto stack_trace = std::make_shared(frames); - map.insert({"dummy_name", stack_trace}); - GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); - - EXPECT_EQ(generated.files_size(), 1); - EXPECT_EQ(generated.files()[0], "dummy_file_name"); - - EXPECT_EQ(generated.traces_size(), 1); - 
EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols().size(), 2); - - EXPECT_EQ( - generated.traces().at("dummy_name").file_line_cols()[0].file_index(), 0); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].line(), 10); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].func(), - "dummy_function_name"); - - EXPECT_EQ( - generated.traces().at("dummy_name").file_line_cols()[1].file_index(), 0); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[1].line(), 20); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[1].func(), - "other_function_name"); -} - -TEST(StackTracesMapToGraphDebugInfoTest, TwoFramesDifferentFile) { - StackTracesMap map; - std::vector frames = { - StackFrame({"dummy_file_name", 10, "dummy_function_name"}), - StackFrame({"other_file_name", 20, "other_function_name"})}; - auto stack_trace = std::make_shared(frames); - map.insert({"dummy_name", stack_trace}); - GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); - - EXPECT_EQ(generated.files_size(), 2); - EXPECT_EQ(generated.files()[0], "dummy_file_name"); - EXPECT_EQ(generated.files()[1], "other_file_name"); - - EXPECT_EQ(generated.traces_size(), 1); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols().size(), 2); - - EXPECT_EQ( - generated.traces().at("dummy_name").file_line_cols()[0].file_index(), 0); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].line(), 10); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[0].func(), - "dummy_function_name"); - - EXPECT_EQ( - generated.traces().at("dummy_name").file_line_cols()[1].file_index(), 1); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[1].line(), 20); - EXPECT_EQ(generated.traces().at("dummy_name").file_line_cols()[1].func(), - "other_function_name"); -} - -TEST(StackTracesTest, ToFrames) { - StackTracesMap map; - std::vector frames = { - StackFrame({"dummy_file_name", 10, "dummy_function_name"}), - StackFrame({"other_file_name", 20, "other_function_name"})}; - auto stack_trace = TestStackTrace(frames); - EXPECT_EQ(stack_trace.ToFrames().size(), 2); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index bcb255f02fef5c..ff20b14485457b 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -21,10 +21,7 @@ limitations under the License. #include #include #include -#include -#include #include -#include #include #include "absl/algorithm/container.h" @@ -45,6 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -233,8 +231,6 @@ class GraphConstructor { FunctionDefLibraryStackTraces CreateStackTracesForFunctionDefLibrary( const FunctionDefLibrary& library) const; - std::shared_ptr CreateStackTraceForNode( - absl::string_view node_name) const; void Undo(); @@ -321,6 +317,8 @@ class GraphConstructor { // A copy of opts_.prefix, possibly uniquified. string prefix_; + StackTracesMap traces_; + ShapeRefiner* refiner_; // May be null. Not owned. 
@@ -1150,49 +1148,26 @@ void GraphConstructor::PrintCycles() { } } -FunctionDefLibraryStackTraces -GraphConstructor::CreateStackTracesForFunctionDefLibrary( - const FunctionDefLibrary& library) const { - FunctionDefLibraryStackTraces library_traces; - if (debug_info() == nullptr) { - return library_traces; - } - for (const FunctionDef& fdef : library.function()) { - const std::string& function_name = fdef.signature().name(); - StackTracesMap stack_traces; - std::string key_suffix = absl::StrCat("@", function_name); - for (const auto& [traces_key, stack_trace] : debug_info()->traces()) { - if (!absl::EndsWith(traces_key, key_suffix)) continue; - std::string node_key = - std::string(absl::StripSuffix(traces_key, key_suffix)); - stack_traces[node_key] = - std::make_shared(stack_trace, *debug_info()); - } - if (!stack_traces.empty()) { - library_traces[function_name] = std::move(stack_traces); - } - } - return library_traces; -} - -std::shared_ptr GraphConstructor::CreateStackTraceForNode( - absl::string_view node_name) const { - if (debug_info() == nullptr) { - return nullptr; - } - auto iterator = debug_info()->traces().find(node_name); - if (iterator != debug_info()->traces().end()) { - return std::make_shared(iterator->second, *debug_info()); +Status GraphConstructor::Convert() { + if (debug_info() != nullptr) { + traces_ = LoadTracesFromDebugInfo(*debug_info()); } - return nullptr; -} -Status GraphConstructor::Convert() { // Import functions before adding nodes, since imported nodes may refer to // functions if (auto library = consume_library(); library.has_value()) { - FunctionDefLibraryStackTraces library_traces = - CreateStackTracesForFunctionDefLibrary(library.value()); + FunctionDefLibraryStackTraces library_traces; + for (const FunctionDef& fdef : library->function()) { + const std::string& function_name = fdef.signature().name(); + StackTracesMap& function_traces = library_traces[function_name]; + std::string key_suffix = absl::StrCat("@", function_name); + for (const auto& [traces_key, stack_trace] : traces_) { + if (!absl::EndsWith(traces_key, key_suffix)) continue; + std::string node_key = + std::string(absl::StripSuffix(traces_key, key_suffix)); + function_traces[node_key] = stack_trace; + } + } TF_RETURN_IF_ERROR( g_->AddFunctionLibrary(*std::move(library), library_traces)); } @@ -1324,10 +1299,8 @@ Status GraphConstructor::Convert() { TF_RETURN_IF_ERROR(MakeNode(std::move(node_def), &node)); if (node != nullptr) { - std::shared_ptr stack_trace = - CreateStackTraceForNode(node_name); - if (stack_trace != nullptr) { - node->SetStackTrace(stack_trace); + if (traces_.contains(node_name)) { + node->SetStackTrace(traces_[node_name]); } } @@ -1527,7 +1500,6 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, g_->AddEdge(src, output_index, dst, input_index); return OkStatus(); } - } // namespace Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, @@ -1549,7 +1521,8 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, } Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, - gtl::ArraySlice nodes, Graph* g) { + gtl::ArraySlice nodes, Graph* g, + const GraphDebugInfo* debug_info) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); // TODO(irving): Copy will go away once NodeInfo exists std::vector node_defs; @@ -1557,8 +1530,9 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, for (const auto& n : nodes) { node_defs.push_back(&n); } - return GraphConstructor::Construct(opts, node_defs, 
nullptr, nullptr, nullptr, - g, &refiner, /*return_tensors=*/nullptr, + return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, + debug_info, g, &refiner, + /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, /*missing_unused_input_map_keys=*/nullptr); } diff --git a/tensorflow/core/common_runtime/graph_constructor.h b/tensorflow/core/common_runtime/graph_constructor.h index a4a602891a9ef7..cb164732da150e 100644 --- a/tensorflow/core/common_runtime/graph_constructor.h +++ b/tensorflow/core/common_runtime/graph_constructor.h @@ -31,7 +31,7 @@ class ShapeRefiner; // nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph, // see ImportGraphDef. struct GraphConstructorOptions { - GraphConstructorOptions() {} + GraphConstructorOptions() = default; // If true, allows internal ops in the GraphDef. bool allow_internal_ops = false; @@ -61,8 +61,9 @@ extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, // Same as ConvertGraphDefToGraph, but takes just nodes. Used by function // instantiation. // TODO(irving): This will turn into std::vector soon. -extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, - gtl::ArraySlice nodes, Graph* g); +extern Status ConvertNodeDefsToGraph( + const GraphConstructorOptions& opts, gtl::ArraySlice nodes, + Graph* g, const GraphDebugInfo* debug_info = nullptr); // Options for calling ImportGraphDef(). struct ImportGraphDefOptions { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 6b7f530f98d7c7..224a2c5bb82d67 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1223,106 +1223,6 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { return OkStatus(); } -// Ignore the frames containing this substring for common prefix calculation. -static const char* kFilenameToIgnorePrefix = " stack_frames, - int shared_prefix_length) { - return absl::StrJoin( - stack_frames, "\n", [&](std::string* out, const StackFrame& frame) { - absl::StrAppend(out, StackFrameToString(frame, shared_prefix_length)); - }); -} - -FrozenStackTrace::FrozenStackTrace(absl::Span frames, - absl::Span user_frames) - : frames_(frames.begin(), frames.end()), - user_frames_(user_frames.begin(), user_frames.end()) { - if (user_frames.empty()) { - user_frames_ = frames_; - } -} - -FrozenStackTrace::FrozenStackTrace( - const GraphDebugInfo::StackTrace& stack_trace, - const GraphDebugInfo& debug_info) { - for (const GraphDebugInfo::FileLineCol& file_line_col : - stack_trace.file_line_cols()) { - int file_index = file_line_col.file_index(); - std::string file_name = - (file_index >= 0 && file_index < debug_info.files_size()) - ? 
debug_info.files(file_index) - : ""; - frames_.push_back( - StackFrame(file_name, file_line_col.line(), file_line_col.func())); - } -} - -absl::Span FrozenStackTrace::ToFrames() const { - return frames_; -} - -StackFrame FrozenStackTrace::LastUserFrame() const { return frames_.back(); } - -std::vector FrozenStackTrace::GetUserFrames(int limit) const { - std::vector result; - if (limit < 0 || limit > user_frames_.size()) { - limit = user_frames_.size(); - } - result.reserve(limit); - for (int i = 0; i < limit; ++i) { - result.push_back(user_frames_[i]); - } - return result; -} - -std::string FrozenStackTrace::ToString(const TracePrintingOptions& opts) const { - int shared_prefix_length = 0; - if (opts.filter_common_prefix) { - std::vector prefix_file_names; - for (const StackFrame& frame : frames_) { - if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) { - prefix_file_names.push_back(frame.file_name); - } - } - shared_prefix_length = tsl::io::CommonPathPrefix(prefix_file_names).size(); - } - - if (!opts.drop_internal_frames) { - return ToStringHelper(frames_, shared_prefix_length); - } - - std::vector non_internal_frames; - for (const StackFrame& frame : frames_) { - if (!IsInternalFrameForFilename(frame.file_name)) { - non_internal_frames.push_back(frame); - } - } - return ToStringHelper(non_internal_frames, shared_prefix_length); -} - -tensorflow::GraphDebugInfo StackTracesMapToGraphDebugInfo( - const tensorflow::StackTracesMap& map, bool user_frames) { - GraphDebugInfoBuilder builder; - GraphDebugInfoBuilder::Options options; - options.user_frames = user_frames; - options.user_frames_limit = -1; - builder.AccumulateStackTracesMap(map, "", options); - return builder.Build(); -} - FunctionRecord::FunctionRecord(const FunctionDef& fdef, const StackTracesMap& stack_traces, bool finalized) diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index cf2c3706254752..5b7ae9800107b3 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -40,14 +40,13 @@ limitations under the License. #include "tensorflow/core/framework/optimized_function_graph.pb.h" #include "tensorflow/core/framework/registration/registration.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/random.h" +#include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/threadpool_interface.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/tsl/protobuf/error_codes.pb.h" @@ -355,80 +354,10 @@ class FunctionCallFrame : public CallFrameInterface { TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); }; -// Language agnostic stack traces. -class AbstractStackTrace { - public: - struct TracePrintingOptions { - // Show inline the contents of each stack line. - bool show_line_contents = false; - - // Drop the common largest prefix of all filenames in stack frames. - bool filter_common_prefix = false; - - // Do not show internal frames. 
- bool drop_internal_frames = false; - }; - - virtual ~AbstractStackTrace() {} - - // The returned span is alive as long as the AbstractStackTrace is alive. - virtual absl::Span ToFrames() const = 0; - - // Returns the last stack frame from user code, attempting to ignore the - // framework code. Returns an empty frame if no such stack frame was found. - virtual StackFrame LastUserFrame() const = 0; - - // Returns stack trace from user code (instead of op creation ones returned in - // ToFrames). - virtual std::vector GetUserFrames(int limit) const = 0; - - virtual std::string ToString(const TracePrintingOptions& opts) const = 0; -}; - -// A frozen sequence of StackFrames; an adapter for a span of StackFrames that -// conforms to the AbstractStackTrace contract. -class FrozenStackTrace : public AbstractStackTrace { - public: - // Constructs a FrozenStackTrace from a span of StackFrames by making a copy - // of each stack frame. - explicit FrozenStackTrace(absl::Span frames, - absl::Span user_frames = {}); - - explicit FrozenStackTrace(std::vector&& frames) - : frames_(std::move(frames)), user_frames_({}) {} - - // Constructs a FrozenStackTrace from serialized proto data. - FrozenStackTrace(const GraphDebugInfo::StackTrace& stack_trace, - const GraphDebugInfo& debug_info); - - ~FrozenStackTrace() override = default; - - absl::Span ToFrames() const override; - - StackFrame LastUserFrame() const override; - - std::vector GetUserFrames(int limit) const override; - - std::string ToString(const TracePrintingOptions& opts) const override; - - private: - std::vector frames_; - std::vector user_frames_; -}; - -using StackTracesMap = - std::unordered_map>; - // Map of function names to StackTracesMaps. using FunctionDefLibraryStackTraces = absl::flat_hash_map; -// Generates a GraphDebugInfo proto from a StackTracesMap object. Returns user -// frames by default. If `user_frames` is false, returns all frames. -tensorflow::GraphDebugInfo StackTracesMapToGraphDebugInfo( - const tensorflow::StackTracesMap& map, bool user_frames = true); - // Holds Function information that can be shared in multiple places. // FunctionRecord must be explicitly finalized before being saved in // FunctionLibraryDefinition or any other place that expects immutability. diff --git a/tensorflow/core/framework/graph_debug_info.proto b/tensorflow/core/framework/graph_debug_info.proto index 01d820104c1f86..9a493445dba60b 100644 --- a/tensorflow/core/framework/graph_debug_info.proto +++ b/tensorflow/core/framework/graph_debug_info.proto @@ -1,4 +1,4 @@ -syntax = "proto3"; +syntax = "proto2"; package tensorflow; @@ -13,32 +13,40 @@ message GraphDebugInfo { message FileLineCol { // File name index, which can be used to retrieve the file name string from // `files`. The value should be between 0 and (len(files)-1) - int32 file_index = 1; + optional int32 file_index = 1; // Line number in the file. - int32 line = 2; + optional int32 line = 2; // Col number in the file line. - int32 col = 3; + optional int32 col = 3; // Name of function contains the file line. - string func = 4; + optional string func = 4; // Source code contained in this file line. - string code = 5; + optional string code = 5; } // This represents a stack trace which is a ordered list of `FileLineCol`. message StackTrace { - // Each line in the stack trace. - repeated FileLineCol file_line_cols = 1; + repeated FileLineCol file_line_cols = 1; // Deprecated. 
+ repeated fixed64 frame_id = 2 [packed = true]; } // This stores all the source code file names and can be indexed by the // `file_index`. repeated string files = 1; - // This maps a node name to a stack trace in the source code. + // Stack traces and frames are uniqueified during construction. These maps + // index from the unique id for a frame/trace to the value. + map frames_by_id = 4; + map traces_by_id = 6; + + map traces = 2; // Deprecated. + + // This maps a node name to a trace id contained in `traces_by_id`. + // // The map key is a mangling of the containing function and op name with // syntax: // op.name '@' func_name @@ -49,5 +57,5 @@ message GraphDebugInfo { // names accept a much wider set of characters. // It would be preferable to avoid mangling and use a tuple key of (op.name, // func_name), but this is not supported with protocol buffers. - map traces = 2; + map name_to_trace_id = 5; } diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a71e22d4dfde6a..ac7135ae36b060 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -1074,7 +1074,7 @@ GraphDebugInfo Graph::BuildDebugInfo() const { const std::shared_ptr& stack_trace = node->GetStackTrace(); if (stack_trace != nullptr) { - builder.AccumulateStackTrace(*stack_trace, node->name()); + builder.AccumulateStackTrace(stack_trace, node->name()); } } diff --git a/tensorflow/core/graph/graph_debug_info_builder.cc b/tensorflow/core/graph/graph_debug_info_builder.cc index db21242916ed83..89f596238bf29d 100644 --- a/tensorflow/core/graph/graph_debug_info_builder.cc +++ b/tensorflow/core/graph/graph_debug_info_builder.cc @@ -15,59 +15,271 @@ limitations under the License. #include "tensorflow/core/graph/graph_debug_info_builder.h" +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/framework/logging.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/stack_frame.h" +#include "tensorflow/tsl/platform/path.h" namespace tensorflow { +// Ignore the frames containing this substring for common prefix calculation. +static const char* kFilenameToIgnorePrefix = " stack_frames, + int shared_prefix_length) { + return absl::StrJoin( + stack_frames, "\n", [&](std::string* out, const StackFrame& frame) { + absl::StrAppend(out, StackFrameToString(frame, shared_prefix_length)); + }); +} + +FrozenStackTrace::FrozenStackTrace(absl::Span frames, + absl::Span user_frames) + : frames_(frames.begin(), frames.end()), + user_frames_(user_frames.begin(), user_frames.end()) { + if (user_frames.empty()) { + user_frames_ = frames_; + } +} + +FrozenStackTrace::FrozenStackTrace( + const GraphDebugInfo::StackTrace& stack_trace, + const GraphDebugInfo& debug_info) { + auto push_frame = [this, + &debug_info](const GraphDebugInfo::FileLineCol& frame) { + int file_index = frame.file_index(); + std::string file_name = + (file_index >= 0 && file_index < debug_info.files_size()) + ? 
debug_info.files(file_index) + : ""; + frames_.push_back(StackFrame(file_name, frame.line(), frame.func())); + }; + + if (!stack_trace.file_line_cols().empty()) { + for (const GraphDebugInfo::FileLineCol& frame : + stack_trace.file_line_cols()) { + push_frame(frame); + } + } else { + for (const uint64_t frame_id : stack_trace.frame_id()) { + if (debug_info.frames_by_id().contains(frame_id)) { + push_frame(debug_info.frames_by_id().at(frame_id)); + } else { + LOG_FIRST_N(ERROR, 5) << "No matching frame for id:" << frame_id; + } + } + } +} + +absl::Span FrozenStackTrace::ToFrames() const { + return frames_; +} + +StackFrame FrozenStackTrace::LastUserFrame() const { return frames_.back(); } + +std::vector FrozenStackTrace::GetUserFrames(int limit) const { + std::vector result; + if (limit < 0 || limit > user_frames_.size()) { + limit = user_frames_.size(); + } + result.reserve(limit); + for (int i = 0; i < limit; ++i) { + result.push_back(user_frames_[i]); + } + return result; +} + +std::string FrozenStackTrace::ToString(const TracePrintingOptions& opts) const { + int shared_prefix_length = 0; + if (opts.filter_common_prefix) { + std::vector prefix_file_names; + for (const StackFrame& frame : frames_) { + if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) { + prefix_file_names.push_back(frame.file_name); + } + } + shared_prefix_length = tsl::io::CommonPathPrefix(prefix_file_names).size(); + } + + if (!opts.drop_internal_frames) { + return ToStringHelper(frames_, shared_prefix_length); + } + + std::vector non_internal_frames; + for (const StackFrame& frame : frames_) { + if (!IsInternalFrameForFilename(frame.file_name)) { + non_internal_frames.push_back(frame); + } + } + return ToStringHelper(non_internal_frames, shared_prefix_length); +} + +GraphDebugInfoBuilder::GraphDebugInfoBuilder() + : debug_info_(std::make_unique()) {} + void GraphDebugInfoBuilder::AccumulateStackTracesMap( const StackTracesMap& stack_traces_map, absl::string_view key_suffix, const GraphDebugInfoBuilder::Options& options) { + trace_to_index_.reserve(trace_to_index_.size() + stack_traces_map.size()); for (const auto& [node_name, stack_trace] : stack_traces_map) { if (stack_trace == nullptr) continue; std::string trace_key = absl::StrCat(node_name, key_suffix); - AccumulateStackTrace(*stack_trace, trace_key, options); + AccumulateStackTrace(stack_trace, trace_key, options); } } void GraphDebugInfoBuilder::AccumulateStackTrace( - const AbstractStackTrace& abstract_stack_trace, - absl::string_view traces_key, + std::shared_ptr trace, absl::string_view traces_key, const GraphDebugInfoBuilder::Options& options) { - GraphDebugInfo::StackTrace stack_trace_proto; - if (options.user_frames) { - for (const auto& stack_frame : - abstract_stack_trace.GetUserFrames(options.user_frames_limit)) { - AppendToStackTraceProto(stack_frame, stack_trace_proto); - } + int trace_index = 0; + StackTracePointer p{trace}; + auto found = trace_to_index_.find(p); + if (found != trace_to_index_.end()) { + trace_index = found->second; } else { - for (const auto& stack_frame : abstract_stack_trace.ToFrames()) { - AppendToStackTraceProto(stack_frame, stack_trace_proto); + trace_index = debug_info_->traces_by_id().size(); + trace_to_index_[p] = trace_index; + GraphDebugInfo::StackTrace& stack_trace_proto = + (*debug_info_->mutable_traces_by_id())[trace_index]; + if (options.user_frames) { + frame_to_index_.reserve( + frame_to_index_.size() + + trace->GetUserFrames(options.user_frames_limit).size()); + for (const auto& stack_frame : + 
trace->GetUserFrames(options.user_frames_limit)) { + AppendToStackTraceProto(stack_frame, stack_trace_proto); + } + } else { + frame_to_index_.reserve(frame_to_index_.size() + + trace->ToFrames().size()); + for (const auto& stack_frame : trace->ToFrames()) { + AppendToStackTraceProto(stack_frame, stack_trace_proto); + } } } - (*debug_info_.mutable_traces())[traces_key] = std::move(stack_trace_proto); + (*debug_info_->mutable_name_to_trace_id())[traces_key] = trace_index; } void GraphDebugInfoBuilder::AppendToStackTraceProto( const StackFrame& stack_frame, GraphDebugInfo::StackTrace& stack_trace_proto) { - auto& file_line_col = *stack_trace_proto.add_file_line_cols(); - if (file_name_to_index_.contains(stack_frame.file_name)) { - file_line_col.set_file_index(file_name_to_index_[stack_frame.file_name]); + int frame_index = 0; + auto found = frame_to_index_.find(stack_frame); + if (found != frame_to_index_.end()) { + frame_index = found->second; } else { - file_line_col.set_file_index(new_name_index_); - file_name_to_index_[stack_frame.file_name] = new_name_index_; - *debug_info_.add_files() = stack_frame.file_name; - new_name_index_++; + frame_index = debug_info_->frames_by_id().size(); + frame_to_index_[stack_frame] = frame_index; + GraphDebugInfo::FileLineCol& frame = + (*debug_info_->mutable_frames_by_id())[frame_index]; + auto file_index = file_name_to_index_.find(stack_frame.file_name); + if (file_index != file_name_to_index_.end()) { + frame.set_file_index(file_index->second); + } else { + frame.set_file_index(new_name_index_); + file_name_to_index_[stack_frame.file_name] = new_name_index_; + *debug_info_->add_files() = stack_frame.file_name; + new_name_index_++; + } + frame.set_line(stack_frame.line_number); + frame.set_func(stack_frame.function_name); + } + stack_trace_proto.add_frame_id(frame_index); +} + +void GraphDebugInfoBuilder::AppendGraphDebugInfo( + absl::string_view prefix, const GraphDebugInfo& new_info) { + for (const auto& pair : new_info.name_to_trace_id()) { + auto trace = new_info.traces_by_id().at(pair.second); + auto frozen = std::make_shared(trace, new_info); + std::string key = + prefix.empty() ? 
pair.first : absl::StrCat(pair.first, "@", prefix); + AccumulateStackTrace(frozen, key, GraphDebugInfoBuilder::Options{}); + } +} + +GraphDebugInfo GraphDebugInfoBuilder::Build() const { return *debug_info_; } + +absl::Status GraphDebugInfoBuilder::AppendGraphDebugInfoStr( + absl::string_view prefix, absl::string_view new_info_str) { + GraphDebugInfo debug_info; + if (!debug_info.ParseFromArray(new_info_str.data(), new_info_str.size())) { + return absl::InvalidArgumentError("Failed to parse GraphDebugInfo proto."); + } + AppendGraphDebugInfo(prefix, debug_info); + return absl::OkStatus(); +} + +std::string GraphDebugInfoBuilder::ToGraphDebugInfoStr() const { + return Build().SerializeAsString(); +} + +StackTracesMap LoadTracesFromDebugInfo(const GraphDebugInfo& debug_info) { + StackTracesMap traces; + absl::flat_hash_map> + traces_by_id; + traces_by_id.reserve(debug_info.traces_by_id_size()); + for (const auto& [id, frames] : debug_info.traces_by_id()) { + traces_by_id[id] = std::make_shared(frames, debug_info); + } + + traces.reserve(debug_info.name_to_trace_id_size() + debug_info.traces_size()); + for (const auto& [name, trace_id] : debug_info.name_to_trace_id()) { + if (!traces_by_id.contains(trace_id)) { + LOG_FIRST_N(ERROR, 5) << "No matching trace for id:" << trace_id; + continue; + } + traces[name] = traces_by_id[trace_id]; } - file_line_col.set_line(stack_frame.line_number); - file_line_col.set_func(stack_frame.function_name); + + for (const auto& [name, frames] : debug_info.traces()) { + traces[name] = std::make_shared(frames, debug_info); + } + + return traces; } -GraphDebugInfo GraphDebugInfoBuilder::Build() const { return debug_info_; } +absl::StatusOr LoadTracesFromDebugInfoStr( + absl::string_view debug_info_str) { + GraphDebugInfo debug_info; + if (!debug_info.ParseFromArray(debug_info_str.data(), + debug_info_str.size())) { + return absl::InvalidArgumentError("Failed to parse GraphDebugInfo proto."); + } + return LoadTracesFromDebugInfo(debug_info); +} + +GraphDebugInfo StackTracesMapToGraphDebugInfo(const StackTracesMap& map, + bool user_frames) { + GraphDebugInfoBuilder builder; + GraphDebugInfoBuilder::Options options; + options.user_frames = user_frames; + options.user_frames_limit = -1; + builder.AccumulateStackTracesMap(map, "", options); + return builder.Build(); +} } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_debug_info_builder.h b/tensorflow/core/graph/graph_debug_info_builder.h index 570f3a43ec5339..392f79bea45ed9 100644 --- a/tensorflow/core/graph/graph_debug_info_builder.h +++ b/tensorflow/core/graph/graph_debug_info_builder.h @@ -16,15 +16,85 @@ limitations under the License. #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEBUG_INFO_BUILDER_H_ #define TENSORFLOW_CORE_GRAPH_GRAPH_DEBUG_INFO_BUILDER_H_ +#include #include - -#include "tensorflow/core/framework/function.h" +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/tsl/platform/macros.h" namespace tensorflow { +// Language agnostic stack traces. +class AbstractStackTrace { + public: + struct TracePrintingOptions { + // Show inline the contents of each stack line. + bool show_line_contents = false; + + // Drop the common largest prefix of all filenames in stack frames. 
+ bool filter_common_prefix = false; + + // Do not show internal frames. + bool drop_internal_frames = false; + }; + + virtual ~AbstractStackTrace() = default; + + // The returned span is alive as long as the AbstractStackTrace is alive. + virtual absl::Span ToFrames() const = 0; + + // Returns the last stack frame from user code, attempting to ignore the + // framework code. Returns an empty frame if no such stack frame was found. + virtual StackFrame LastUserFrame() const = 0; + + // Returns stack trace from user code (instead of op creation ones returned in + // ToFrames). + virtual std::vector GetUserFrames(int limit) const = 0; + + virtual std::string ToString(const TracePrintingOptions& opts) const = 0; +}; + +// A frozen sequence of StackFrames; an adapter for a span of StackFrames that +// conforms to the AbstractStackTrace contract. +class FrozenStackTrace : public AbstractStackTrace { + public: + // Constructs a FrozenStackTrace from a span of StackFrames by making a copy + // of each stack frame. + explicit FrozenStackTrace(absl::Span frames, + absl::Span user_frames = {}); + + explicit FrozenStackTrace(std::vector&& frames) + : frames_(std::move(frames)), user_frames_({}) {} + + FrozenStackTrace(FrozenStackTrace&&) = default; + + // Constructs a FrozenStackTrace from serialized proto data. + FrozenStackTrace(const GraphDebugInfo::StackTrace& stack_trace, + const GraphDebugInfo& debug_info); + + ~FrozenStackTrace() override = default; + + absl::Span ToFrames() const override; + + StackFrame LastUserFrame() const override; + + std::vector GetUserFrames(int limit) const override; + + std::string ToString(const TracePrintingOptions& opts) const override; + + private: + std::vector frames_; + std::vector user_frames_; +}; + // Builder for GraphDebugInfo protos from either an existing map of string keys // to stack traces, or individual stack traces, or both. All stack traces in a // GraphDebugInfo are stored with a string key in the `traces` field. In the @@ -40,6 +110,39 @@ namespace tensorflow { // Typical usage is to call one or both of the accumulate methods one or more // times and then to call the Build(). +// Holder type to use `AbstractStackTrace` as a key. +struct StackTracePointer { + std::shared_ptr trace; + + template + friend H AbslHashValue(H h, const StackTracePointer& p) { + for (const auto& frame : p.trace->ToFrames()) { + h = H::combine(std::move(h), frame); + } + return h; + } + + bool operator==(const StackTracePointer& other) const { + absl::Span other_frames = other.trace->ToFrames(); + absl::Span frames = trace->ToFrames(); + return frames == other_frames; + } +}; + +using StackTracesMap = + absl::flat_hash_map>; + +// Load all stack traces from `debug_info`. +StackTracesMap LoadTracesFromDebugInfo(const GraphDebugInfo& debug_info); +absl::StatusOr LoadTracesFromDebugInfoStr( + absl::string_view debug_info_str); + +// Generates a GraphDebugInfo proto from a StackTracesMap object. Returns user +// frames by default. If `user_frames` is false, returns all frames. +GraphDebugInfo StackTracesMapToGraphDebugInfo(const StackTracesMap& map, + bool user_frames = true); + class GraphDebugInfoBuilder { public: struct Options { @@ -50,7 +153,8 @@ class GraphDebugInfoBuilder { int user_frames_limit; }; - explicit GraphDebugInfoBuilder() = default; + GraphDebugInfoBuilder(); + virtual ~GraphDebugInfoBuilder() = default; // Adds a map of stack traces to the GraphDebugInfo proto. 
For each key (node // id) and stack traces entry in `stack_traces_map`, combine the key with @@ -64,11 +168,21 @@ class GraphDebugInfoBuilder { // Adds one stack trace to the GraphDebugInfo proto, using `traces_key` as the // key for the `traces` field of the proto. - void AccumulateStackTrace(const AbstractStackTrace& abstract_stack_trace, + void AccumulateStackTrace(std::shared_ptr trace, absl::string_view traces_key, const GraphDebugInfoBuilder::Options& options = GraphDebugInfoBuilder::Options()); + void AppendGraphDebugInfo(absl::string_view prefix, + const GraphDebugInfo& new_info); + + // These string methods are used in the Python bindings to avoid symbol + // resolution errors with pybind on Windows. + absl::Status AppendGraphDebugInfoStr(absl::string_view prefix, + absl::string_view new_info_str); + + std::string ToGraphDebugInfoStr() const; + // Returns the GraphDebugInfo proto. GraphDebugInfo Build() const; @@ -76,8 +190,11 @@ class GraphDebugInfoBuilder { void AppendToStackTraceProto(const StackFrame& stack_frame, GraphDebugInfo::StackTrace& stack_trace_proto); - GraphDebugInfo debug_info_; + std::unique_ptr debug_info_; absl::flat_hash_map file_name_to_index_; + + absl::flat_hash_map trace_to_index_; + absl::flat_hash_map frame_to_index_; int new_name_index_ = 0; TF_DISALLOW_COPY_AND_ASSIGN(GraphDebugInfoBuilder); diff --git a/tensorflow/core/graph/graph_debug_info_builder_test.cc b/tensorflow/core/graph/graph_debug_info_builder_test.cc index 439084122214ce..6c236de97d0bba 100644 --- a/tensorflow/core/graph/graph_debug_info_builder_test.cc +++ b/tensorflow/core/graph/graph_debug_info_builder_test.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/core/graph/graph_debug_info_builder.h" #include +#include #include #include +#include #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/platform/test.h" @@ -51,7 +53,7 @@ class TestStackTrace : public AbstractStackTrace { }; TEST(GraphDebugInfoBuilderTest, AccumulateStackTrace) { - TestStackTrace stack_trace( + auto stack_trace = std::make_shared( std::vector{{"dummy_file_alpha.cc", 20, "function_bar"}, {"dummy_file_beta.cc", 30, "function_sop"}}); @@ -61,12 +63,14 @@ TEST(GraphDebugInfoBuilderTest, AccumulateStackTrace) { EXPECT_THAT(debug_info.files(), UnorderedElementsAre("dummy_file_alpha.cc", "dummy_file_beta.cc")); - EXPECT_THAT(debug_info.traces_size(), Eq(1)); - - EXPECT_THAT(debug_info.traces().find("alpha_beta"), - Ne(debug_info.traces().end())); - auto actual_stack_trace = debug_info.traces().find("alpha_beta")->second; - EXPECT_THAT(actual_stack_trace.file_line_cols_size(), Eq(2)); + EXPECT_THAT(debug_info.traces_by_id_size(), Eq(1)); + + EXPECT_THAT(debug_info.name_to_trace_id().find("alpha_beta"), + Ne(debug_info.name_to_trace_id().end())); + auto actual_stack_trace = debug_info.traces_by_id().at( + debug_info.name_to_trace_id().at("alpha_beta")); + EXPECT_THAT(actual_stack_trace.frame_id_size(), Eq(2)) + << debug_info.DebugString(); } TEST(GraphDebugInfoBuilderTest, AccumulateStackTracesMap) { @@ -91,19 +95,25 @@ TEST(GraphDebugInfoBuilderTest, AccumulateStackTracesMap) { EXPECT_THAT(debug_info.files(), UnorderedElementsAre("dummy_file_alpha.cc", "dummy_file_beta.cc")); - EXPECT_THAT(debug_info.traces_size(), Eq(3)); + EXPECT_THAT(debug_info.name_to_trace_id_size(), Eq(3)); // Examine one of the three stack traces in detail. 
- EXPECT_THAT(debug_info.traces().find("scale@func"), - Ne(debug_info.traces().end())); - auto stack_trace = debug_info.traces().find("scale@func")->second; - EXPECT_THAT(stack_trace.file_line_cols_size(), Eq(2)); + EXPECT_THAT(debug_info.name_to_trace_id().find("scale@func"), + Ne(debug_info.name_to_trace_id().end())); + auto stack_trace = debug_info.traces_by_id().at( + debug_info.name_to_trace_id().at("scale@func")); + EXPECT_THAT(stack_trace.frame_id_size(), Eq(2)); + + std::vector file_line_cols; + for (auto& frame_id : stack_trace.frame_id()) { + file_line_cols.push_back(debug_info.frames_by_id().at(frame_id)); + } // `FileLineCol.file_index` is non-deterministic because the GraphDebugInfo is // built by accumulating all file names into a set, and then storing that in // the `files` field in an arbitrary order. - auto file_line_col_0 = stack_trace.file_line_cols(0); - auto file_line_col_1 = stack_trace.file_line_cols(1); + auto file_line_col_0 = file_line_cols[0]; + auto file_line_col_1 = file_line_cols[1]; EXPECT_THAT(std::vector( {file_line_col_0.file_index(), file_line_col_1.file_index()}), UnorderedElementsAre(0, 1)); @@ -113,5 +123,100 @@ TEST(GraphDebugInfoBuilderTest, AccumulateStackTracesMap) { EXPECT_THAT(file_line_col_1.func(), Eq("function_sop")); } +TEST(GraphDebugInfoBuilderTest, AppendGraphDebugInfo) { + GraphDebugInfo a; + + // Function stack traces are commonly returned without a prefix. + // Validate that we can accumulate these correctly. + { + GraphDebugInfoBuilder builder; + StackTracesMap stack_traces; + stack_traces["two"] = std::make_shared( + std::vector{{"dummy_file_alpha.cc", 20, "function_bar"}}); + stack_traces["scale"] = std::make_shared( + std::vector{{"dummy_file_alpha.cc", 10, "function_foo"}}); + builder.AccumulateStackTracesMap(stack_traces, ""); + a = builder.Build(); + } + + GraphDebugInfo b; + { + GraphDebugInfoBuilder builder; + StackTracesMap stack_traces; + stack_traces["y"] = + std::make_shared(std::vector{ + {"dummy_file_alpha.cc", 15, "function_flex"}, + }); + builder.AccumulateStackTracesMap(stack_traces, ""); + b = builder.Build(); + } + + // With builtin prefix + GraphDebugInfo c; + { + GraphDebugInfoBuilder builder; + StackTracesMap stack_traces; + stack_traces["z"] = + std::make_shared(std::vector{ + {"dummy_file_alpha.cc", 15, "function_flex"}, + }); + builder.AccumulateStackTracesMap(stack_traces, "@func3"); + c = builder.Build(); + } + + GraphDebugInfoBuilder builder; + builder.AppendGraphDebugInfo("func1", a); + builder.AppendGraphDebugInfo("func2", b); + builder.AppendGraphDebugInfo("", c); + GraphDebugInfo combined = builder.Build(); + + EXPECT_EQ(combined.name_to_trace_id().size(), 4); + std::vector keys{"two@func1", "scale@func1", "y@func2", + "z@func3"}; + + for (const auto& key : keys) { + EXPECT_THAT(combined.name_to_trace_id().find(key), + Ne(combined.name_to_trace_id().end())); + } +} + +TEST(StackTracesMapToGraphDebugInfoTest, EmptyMap) { + StackTracesMap map; + GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); + + EXPECT_EQ(generated.files_size(), 0); + EXPECT_EQ(generated.traces_size(), 0); +} + +TEST(StackTracesMapToGraphDebugInfoTest, EmptyFrames) { + StackTracesMap map; + std::vector frames; + auto stack_trace = std::make_shared(frames); + map.insert({"dummy_name", stack_trace}); + GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); + + EXPECT_EQ(generated.files_size(), 0); + EXPECT_EQ(generated.traces_by_id_size(), 1); + EXPECT_TRUE(generated.name_to_trace_id().contains("dummy_name")); +} + 
+TEST(StackTracesMapToGraphDebugInfoTest, RoundTripStackTraces) { + StackTracesMap map; + std::vector frames = { + StackFrame({"dummy_file_name", 10, "dummy_function_name"}), + StackFrame({"dummy_file_name", 20, "other_function_name"})}; + auto stack_trace = std::make_shared(frames); + map.insert({"dummy_name", stack_trace}); + GraphDebugInfo generated = StackTracesMapToGraphDebugInfo(map); + + StackTracesMap output = LoadTracesFromDebugInfo(generated); + + for (auto [name, trace] : output) { + auto orig_trace = map[name]; + EXPECT_NE(orig_trace, nullptr); + EXPECT_EQ(orig_trace->ToFrames(), trace->ToFrames()); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 2b948c66de4f62..d80338e02aa29f 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1223,7 +1223,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, if (!builder) { builder = std::make_unique(); } - builder->AccumulateStackTrace(*stack_trace, dst->name()); + builder->AccumulateStackTrace(stack_trace, dst->name()); } } diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 9f44ebdf739c32..3ab6c24d80e20b 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -583,27 +584,24 @@ TEST_F(GraphPartitionTest, GraphDebugInfo) { // A stack trace for A1 should be in the A partition (".../cpu:0"). string a = "/job:a/replica:0/task:0/cpu:0"; const GraphDebugInfo& a_debug_info = partitions_[a].debug_info(); - const auto& a_it = a_debug_info.traces().find("A1"); - EXPECT_EQ(1, a_debug_info.traces().size()); - EXPECT_THAT(a_it, Ne(a_debug_info.traces().end())); - EXPECT_THAT(FormatStackTrace(a_it->second, a_debug_info), - Eq("x@main.cc:20.0\n" - "a1@alpha.cc:30.0\n")); + StackTracesMap traces = LoadTracesFromDebugInfo(a_debug_info); + const auto& a_it = traces.find("A1"); + EXPECT_THAT(a_it, Ne(traces.end())); + EXPECT_THAT(a_it->second->ToString({}), + ::testing::ContainsRegex("alpha.cc.*30")); // Stack traces for B1 and B2 should be in the B partition (".../cpu:1"). 
string b = "/job:a/replica:0/task:0/cpu:1"; const GraphDebugInfo& b_debug_info = partitions_[b].debug_info(); - const auto& b1_it = b_debug_info.traces().find("B1"); - const auto& b2_it = b_debug_info.traces().find("B2"); - EXPECT_EQ(2, b_debug_info.traces().size()); - EXPECT_THAT(b1_it, Ne(b_debug_info.traces().end())); - EXPECT_THAT(b2_it, Ne(b_debug_info.traces().end())); - EXPECT_THAT(FormatStackTrace(b1_it->second, b_debug_info), - Eq("y@window.cc:21.0\n" - "b1@beta.cc:35.0\n")); - EXPECT_THAT(FormatStackTrace(b2_it->second, b_debug_info), - Eq("bar@cache.cc:22.0\n" - "b2@beta.cc:39.0\n")); + traces = LoadTracesFromDebugInfo(b_debug_info); + const auto& b1_it = traces.find("B1"); + const auto& b2_it = traces.find("B2"); + EXPECT_THAT(b1_it, Ne(traces.end())); + EXPECT_THAT(b2_it, Ne(traces.end())); + EXPECT_THAT(b1_it->second->ToString({}), + ::testing::ContainsRegex("beta.cc.*35")); + EXPECT_THAT(b2_it->second->ToString({}), + ::testing::ContainsRegex("beta.cc.*39")); } TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) { diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc index 160efcfc9c8176..4a57af42326402 100644 --- a/tensorflow/core/graph/validate_test.cc +++ b/tensorflow/core/graph/validate_test.cc @@ -241,7 +241,6 @@ TEST(ValidateGraphHasNoCycleTest, CycleFails) { // Need to construct graph explicitly, since GraphDefToGraph has its own // cycle validation routine. Graph graph(OpRegistry::Global()); - GraphConstructorOptions opts; Node* a = AddNodeFromNodeDef(graph, "A", "FloatInput", 0); Node* c = AddNodeFromNodeDef(graph, "B", "Mul", 2); diff --git a/tensorflow/core/ir/importexport/graphdef_import.cc b/tensorflow/core/ir/importexport/graphdef_import.cc index f4709231165ee7..eb09991b6a0973 100644 --- a/tensorflow/core/ir/importexport/graphdef_import.cc +++ b/tensorflow/core/ir/importexport/graphdef_import.cc @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/importexport/convert_attributes.h" @@ -72,6 +73,7 @@ using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistrationData; using tensorflow::OpRegistry; +using tensorflow::StackTracesMap; using tensorflow::Status; using tensorflow::StatusOr; using tensorflow::StringPiece; @@ -97,6 +99,7 @@ class GraphDefImporter { unknown_loc_(UnknownLoc::get(ctx_)), placeholder_state_(unknown_loc_, "tfg._mlir_placeholder") { placeholder_state_.addTypes(dialect_->getControlType()); + stack_traces_ = LoadTracesFromDebugInfo(debug_info); } // Convert a GraphDef to MLIR module. @@ -235,6 +238,8 @@ class GraphDefImporter { // Operation state for creating placeholder ops. OperationState placeholder_state_; + StackTracesMap stack_traces_; + // Map of function OpDefs. absl::flat_hash_map function_op_defs_; }; @@ -463,15 +468,19 @@ Location GraphDefImporter::ConvertLocation(StringRef node_name, auto name_loc_id = b_.getStringAttr(name_loc); SmallVector locs; - const auto &traces = debug_info_.traces(); + // Try to find a stack trace to convert to locations. 
- auto it = traces.find(debug_info_key); - if (it != traces.end()) { - const auto &trace = it->second; - locs.reserve(trace.file_line_cols_size()); - for (const auto &loc : trace.file_line_cols()) { - auto file_name = b_.getStringAttr(debug_info_.files(loc.file_index())); - locs.push_back(FileLineColLoc::get(file_name, loc.line(), loc.col())); + auto it = stack_traces_.find(name_loc); + if (it == stack_traces_.end()) { + it = stack_traces_.find(debug_info_key); + } + if (it != stack_traces_.end()) { + std::shared_ptr trace = it->second; + auto frames = trace->ToFrames(); + locs.reserve(frames.size()); + for (const auto &frame : frames) { + auto file_attr = b_.getStringAttr(frame.file_name); + locs.push_back(FileLineColLoc::get(file_attr, frame.line_number, 1)); } } diff --git a/tensorflow/lite/python/authoring/authoring.py b/tensorflow/lite/python/authoring/authoring.py index 88c982fba0ad28..c8a8df467327cf 100644 --- a/tensorflow/lite/python/authoring/authoring.py +++ b/tensorflow/lite/python/authoring/authoring.py @@ -141,12 +141,9 @@ def get_concrete_function(self, *args, **kwargs): def _get_location_string(self, location): """Dump location of ConveterError.errors.location.""" callstack = [] - for single_call in location.call: + for single_call in reversed(location.call): if (location.type == converter_error_data_pb2.ConverterErrorData.CALLSITELOC): - # Stop showing CallSite after func_graph.py which isn't meaningful. - if _FUNC_GRAPH_SRC_PATH in single_call.source.filename: - break callstack.append( f" - {single_call.source.filename}:{single_call.source.line}") else: diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index 365b6edc2224d1..945d239a14535c 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -1589,7 +1589,7 @@ def plus_placeholder(x, placeholder): # Check the add node in the inlined function is included. func = sess.graph.as_graph_def().library.function[0].signature.name - self.assertIn(('add@' + func), converter._debug_info.traces) + self.assertIn(('add@' + func), repr(converter._debug_info)) def testOutputOnlyModel(self): with ops.Graph().as_default(): diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index c819e3c67df39f..04c9f425ade31b 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -21,7 +21,6 @@ from absl import logging import flatbuffers -from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.core.protobuf import config_pb2 as _config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb @@ -380,18 +379,8 @@ def convert_debug_info_func(saved_debug_info): def f(original_nodes): """Function to create `GraphDebugInfo` for the given `original_nodes`.""" - if not saved_debug_info: - return None - - output_debug_info = graph_debug_info_pb2.GraphDebugInfo() - # All the files are copied over, so the index wouldn't be changed. 
- output_debug_info.files[:] = saved_debug_info.files - # We only copy over the debug info for the input nodes - for func, node in original_nodes: - debug_key = node + "@" + func - output_debug_info.traces[debug_key].CopyFrom( - saved_debug_info.traces[debug_key]) - return output_debug_info + del original_nodes + return saved_debug_info return f diff --git a/tensorflow/python/eager/polymorphic_function/BUILD b/tensorflow/python/eager/polymorphic_function/BUILD index ccd2e8b52f1ea3..543375d2b3af98 100644 --- a/tensorflow/python/eager/polymorphic_function/BUILD +++ b/tensorflow/python/eager/polymorphic_function/BUILD @@ -62,6 +62,7 @@ pytype_strict_library( "//tensorflow/python/types:core", "//tensorflow/python/util:compat", "//tensorflow/python/util:function_utils", + "//tensorflow/python/util:tf_stack", ], ) diff --git a/tensorflow/python/eager/polymorphic_function/atomic_function.py b/tensorflow/python/eager/polymorphic_function/atomic_function.py index ae9ec964659a9f..c2eb5da8d6b3ab 100644 --- a/tensorflow/python/eager/polymorphic_function/atomic_function.py +++ b/tensorflow/python/eager/polymorphic_function/atomic_function.py @@ -17,7 +17,7 @@ import dataclasses import traceback import typing -from typing import Any, Dict, List, Sequence, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -39,6 +39,7 @@ from tensorflow.python.types import core from tensorflow.python.util import compat from tensorflow.python.util import function_utils +from tensorflow.python.util import tf_stack # TODO(fmuham): Should be lowered to FunctionDef/FunctionRecord. @@ -635,24 +636,17 @@ def __init__(self, top_level_func): def interpolate(self, message, node_names, graph_debug_info): """Uses the GraphDebugInfo to generate an error message.""" error_message = ["Graph execution error:", ""] + traces = tf_stack.LoadTracesFromDebugInfo(graph_debug_info) + for node_name in node_names: error_message.append( f"Detected at node {node_name} defined at (most recent call last):" ) - if node_name in graph_debug_info.traces: - stack_trace = graph_debug_info.traces[node_name] - tb_frames = [] - for frame in stack_trace.file_line_cols: - tb_frames.append( - traceback.FrameSummary( - graph_debug_info.files[frame.file_index], - frame.line, - frame.func, - ) - ) - for formatted_frame in traceback.format_list(tb_frames): - if not any(p in formatted_frame for p in self.DENY_LIST_PHRASES): - error_message.append(formatted_frame) + if node_name in traces: + stack_trace = traces[node_name] + for formatted_frame in traceback.format_list(stack_trace): + if not any(p in formatted_frame for p in self.DENY_LIST_PHRASES): + error_message.append(formatted_frame) else: error_message.append("") diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py index 3c0d320a9f4014..aca8107094eb11 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function_test.py @@ -3313,7 +3313,7 @@ def test_fn(): return script_ops.eager_py_func( func=lambda: array_ops.constant([2.]), inp=(), Tout=dtypes.int32) - error_pattern = re.compile(r'Graph execution error.*func=lambda', re.DOTALL) + error_pattern = re.compile(r'Graph execution error.*test_fn', re.DOTALL) with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern): 
test_fn() diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 8b09a321682609..f0ddf57b8ff253 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -518,6 +518,7 @@ py_strict_library( srcs_version = "PY3", deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python/util:tf_stack", ], ) diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index 23a78224ee5088..76c727a5fe5ea1 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -24,7 +24,7 @@ import site import traceback -from tensorflow.core.framework import graph_debug_info_pb2 +from tensorflow.python.util import tf_stack _NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?" _TAG_REGEX = fr"{{{{(?P{_NAME_REGEX}) (?P{_NAME_REGEX})}}}}" @@ -306,80 +306,15 @@ def create_graph_debug_info_def(func_named_operations): Raises: TypeError: If the arguments are not of the correct proto buffer type. """ - # Creates an empty GraphDebugInfoDef proto. - graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo() - - # Gets the file names and line numbers for the exported node names. Also - # collects the unique file names. - all_file_names = set() - node_to_trace = {} + builder = tf_stack.GraphDebugInfoBuilder() for func_name, op in func_named_operations: if op.traceback is None: continue - # Gets the stack trace of the operation and then the file location. - node_name = op.name + "@" + func_name - node_to_trace[node_name] = _compute_useful_frames(op.traceback, 10) - for frame in node_to_trace[node_name]: - all_file_names.add(frame.filename) - - # Sets the `files` field in the GraphDebugInfo proto - graph_debug_info_def.files.extend(all_file_names) - - # Builds a mapping between file names and index of the `files` field, so we - # only store the indexes for the nodes in the GraphDebugInfo. - file_to_index = dict( - [(y, x) for x, y in enumerate(graph_debug_info_def.files)]) - - # Creates the FileLineCol proto for each node and sets the value in the - # GraphDebugInfo proto. We only store the file name index for each node to - # save the storage space. - for node_name, frames in node_to_trace.items(): - trace_def = graph_debug_info_def.traces[node_name] - for frame in reversed(frames): - trace_def.file_line_cols.add( - file_index=file_to_index[frame.filename], - line=frame.lineno) - - return graph_debug_info_def - - -def merge_graph_debug_info_def(per_fn_info): - """Construct and returns a `GraphDebugInfo` protocol buffer. - - Args: - per_fn_info: An iterable of (func_name, GraphDebugInfo) tuples. - - Returns: - GraphDebugInfo protocol buffer. + builder.AccumulateStackTrace( + func_name, op.name, _compute_useful_frames(op.traceback, 10) + ) - Raises: - TypeError: If the arguments are not of the correct proto buffer type. - """ - graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo() - - all_file_names = set() - for _, fn_info in per_fn_info: - all_file_names.update(fn_info.files) - # Ensure determinism. 
- all_file_names = sorted(all_file_names) - - graph_debug_info_def.files.extend(all_file_names) - file_to_index = dict( - [(y, x) for x, y in enumerate(graph_debug_info_def.files)]) - - for fn_name, fn_info in per_fn_info: - for fn_node_name, fn_trace in fn_info.traces.items(): - trace_def = graph_debug_info_def.traces[fn_node_name + "@" + fn_name] - for fn_frame in fn_trace.file_line_cols: - trace_def.file_line_cols.add( - file_index=file_to_index[fn_info.files[fn_frame.file_index]], - line=fn_frame.line, - col=fn_frame.col, - func=fn_frame.func, - code=fn_frame.code, - ) - - return graph_debug_info_def + return builder.Build() def _compute_field_dict(op): diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index fb6512657e5cfd..7f9c35591cc9e0 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -18,7 +18,6 @@ import os import re -from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -133,165 +132,6 @@ def testCorrectFormatWhenNoColocationsWereActive(self): self.assertIn("No node-device colocations", summary) -# Note that the create_graph_debug_info_def needs to run on graph mode ops, -# so it is excluded from eager tests. Even when used in eager mode, it is -# via FunctionGraphs, and directly verifying in graph mode is the narrowest -# way to unit test the functionality. -class CreateGraphDebugInfoDefTest(test.TestCase): - - def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index): - self.assertIn(key, graph_debug_info.traces) - stack_trace = graph_debug_info.traces[key] - found_flc = None - for flc in stack_trace.file_line_cols: - if flc.file_index == file_index: - found_flc = flc - break - self.assertIsNotNone( - found_flc, "Could not find a stack trace entry for file" - ) - return found_flc - - def testStackTraceExtraction(self): - # This test is verifying stack trace information added in graph mode, so - # only makes sense in graph mode. - with ops.Graph().as_default(): - # Since the create_graph_debug_info_def() function does not actually - # do anything special with functions except name mangling, just verify - # it with a loose op and manually provided function name. - # The following ops *must* be on consecutive lines (it will be verified - # in the resulting trace). 
- # pyformat: disable - global_op = constant_op.constant(0, name="Global").op - op1 = constant_op.constant(1, name="One").op - op2 = constant_op.constant(2, name="Two").op - # pyformat: enable - - # Ensure op without traceback does not fail - node_def_copy = type(op1.node_def)() - node_def_copy.CopyFrom(op1.node_def) - node_def_copy.name = "NonTraceback" - c_op = ops._create_c_op( - ops.get_default_graph(), - node_def=node_def_copy, - inputs=[], - control_inputs=[], - extract_traceback=False, - ) - - non_traceback_op = ops.Operation._from_c_op(c_op, ops.get_default_graph()) - self.assertIsNone(non_traceback_op.traceback) - - export_ops = [ - ("", global_op), - ("func1", op1), - ("func2", op2), - ("func2", non_traceback_op), - ] - graph_debug_info = error_interpolation.create_graph_debug_info_def( - export_ops - ) - this_file_index = -1 - for file_index, file_name in enumerate(graph_debug_info.files): - if "{}error_interpolation_test.py".format(os.sep) in file_name: - this_file_index = file_index - self.assertGreaterEqual( - this_file_index, - 0, - "Could not find this file in trace:" + repr(graph_debug_info), - ) - - # Verify the traces exist for each op. - global_flc = self._getFirstStackTraceForFile( - graph_debug_info, "Global@", this_file_index - ) - op1_flc = self._getFirstStackTraceForFile( - graph_debug_info, "One@func1", this_file_index - ) - op2_flc = self._getFirstStackTraceForFile( - graph_debug_info, "Two@func2", this_file_index - ) - - self.assertNotIn("NonTraceback@func2", graph_debug_info.traces) - - global_line = global_flc.line - self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line") - self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line") - - -class MergeGraphDebugInfoDefTest(test.TestCase): - - def testMerges(self): - fn_1 = graph_debug_info_pb2.GraphDebugInfo( - files=["a.py", "b.py", "c.py"], - traces={ - "node_1": graph_debug_info_pb2.GraphDebugInfo.StackTrace( - file_line_cols=[ - graph_debug_info_pb2.GraphDebugInfo.FileLineCol( - file_index=0, line=19, col=2 - ) - ] - ), - "node_2": graph_debug_info_pb2.GraphDebugInfo.StackTrace( - file_line_cols=[ - graph_debug_info_pb2.GraphDebugInfo.FileLineCol( - file_index=1, line=33, col=4 - ) - ] - ), - }, - ) - - fn_2 = graph_debug_info_pb2.GraphDebugInfo( - files=["c.py", "a.py", "b.py"], - traces={ - "node_1": graph_debug_info_pb2.GraphDebugInfo.StackTrace( - file_line_cols=[ - graph_debug_info_pb2.GraphDebugInfo.FileLineCol( - file_index=0, line=9, col=6 - ) - ] - ), - "node_2": graph_debug_info_pb2.GraphDebugInfo.StackTrace( - file_line_cols=[ - graph_debug_info_pb2.GraphDebugInfo.FileLineCol( - file_index=1, line=56, col=7 - ) - ] - ), - }, - ) - - result = error_interpolation.merge_graph_debug_info_def( - [("fn_1", fn_1), ("fn_2", fn_2)] - ) - - self.assertEqual(set(result.files), {"b.py", "a.py", "c.py"}) - self.assertLen(result.traces, 4) - - # Check file names are correctly indexed. - self.assertEqual( - result.files[result.traces["node_1@fn_1"].file_line_cols[0].file_index], - "a.py", - ) - self.assertEqual( - result.files[result.traces["node_2@fn_1"].file_line_cols[0].file_index], - "b.py", - ) - self.assertEqual( - result.files[result.traces["node_1@fn_2"].file_line_cols[0].file_index], - "c.py", - ) - self.assertEqual( - result.files[result.traces["node_2@fn_2"].file_line_cols[0].file_index], - "a.py", - ) - - # Check properties of a node. 
- self.assertEqual(result.traces["node_1@fn_1"].file_line_cols[0].line, 19) - self.assertEqual(result.traces["node_1@fn_1"].file_line_cols[0].col, 2) - - class InterpolateFilenamesAndLineNumbersTest(test.TestCase): def testFindIndexOfDefiningFrameForOp(self): diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index fe04a6f5c4c3d4..054a2e2d354505 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -24,7 +24,6 @@ from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import function from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops @@ -747,29 +746,6 @@ def testScopedWithQueue(self): test_util.assert_meta_graph_protos_equal(self, orig_meta_graph, new_meta_graph) - def testExportDebugInfo(self): - graph1 = ops.Graph() - with graph1.as_default(): - with ops.name_scope("hidden1/hidden2/hidden3"): - images = constant_op.constant( - 1.0, dtypes.float32, shape=[3, 2], name="images") - weights1 = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - name="weights") - biases1 = resource_variable_ops.ResourceVariable( - [0.1] * 3, name="biases") - nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") - func_named_operations = [] - for op in graph1.get_operations(): - func_named_operations.append(("", op)) - debug_info_def = error_interpolation.create_graph_debug_info_def( - func_named_operations) - - # The unique file names in all the stack traces should be larger or equal - # than 1. - self.assertTrue(len(debug_info_def.files) >= 1) - # All the nodes from the exported graphdef are included. - self.assertEqual(len(debug_info_def.traces), len(graph1.get_operations())) - # Verifies that we can export a subgraph in a nested name scope containing a # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new # graph. 
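The error_interpolation.py, meta_graph_test.py, save.py, and tf_stack.py changes in this patch all converge on one pattern: stack traces are accumulated into a serialized GraphDebugInfo through tf_stack.GraphDebugInfoBuilder and read back as a map keyed by "op@function" with tf_stack.LoadTracesFromDebugInfo. A minimal sketch of that round trip, using only the APIs added in this patch ("some_func" and "some_op" are placeholder names, not names from the patch):

    from tensorflow.python.util import tf_stack

    # Capture a stack trace at the point where an op would be created.
    trace = tf_stack.extract_stack()

    # Accumulate traces keyed as "<op>@<function>" and serialize them into a
    # GraphDebugInfo proto.
    builder = tf_stack.GraphDebugInfoBuilder()
    builder.AccumulateStackTrace("some_func", "some_op", trace)
    debug_info = builder.Build()

    # Load the proto back into a map from "<op>@<function>" to stack traces.
    trace_map = tf_stack.LoadTracesFromDebugInfo(debug_info)
    assert "some_op@some_func" in trace_map
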
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 49168e55b893cf..f2bbc1443546a8 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -411,6 +411,7 @@ py_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:object_identity", "//tensorflow/python/util:tf_export", + "//tensorflow/python/util:tf_stack", "@absl_py//absl/logging", ] + if_google([ "//tensorflow/tools/proto_splitter/python:saved_model", diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 110e677d4aca47..ca571d13a99db0 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -40,7 +40,6 @@ from tensorflow.python.eager.polymorphic_function import saved_model_exported_concrete from tensorflow.python.eager.polymorphic_function import saved_model_utils from tensorflow.python.framework import dtypes -from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import function as framework_fn from tensorflow.python.framework import meta_graph @@ -76,6 +75,7 @@ from tensorflow.python.types import core as types_core from tensorflow.python.util import compat from tensorflow.python.util import object_identity +from tensorflow.python.util import tf_stack from tensorflow.python.util.tf_export import tf_export # Placeholder for protosplitter import. @@ -1077,16 +1077,14 @@ def _export_debug_info(exported_graph, export_dir): exported_graph: A Graph that has been created by tracing a saveable view. export_dir: SavedModel directory in which to write the debug info. """ - per_fn_info = [] + debug_builder = tf_stack.GraphDebugInfoBuilder() for fn_name in exported_graph._functions: # pylint: disable=protected-access fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access if not isinstance(fn, defun.AtomicFunction): # pylint: disable=protected-access continue + debug_builder.AppendGraphDebugInfo(fn_name, fn.graph_debug_info) - per_fn_info.append((fn_name, fn.graph_debug_info)) - - graph_debug_info = error_interpolation.merge_graph_debug_info_def( - per_fn_info) + graph_debug_info = debug_builder.Build() file_io.atomic_write_string_to_file( file_io.join( path_helpers.get_or_create_debug_dir(export_dir), diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index ec83138b3e8edc..1b7d9c122f4123 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -1093,11 +1093,13 @@ def test_save_debug_info_enabled(self): # Verify that there is a trace for DEBUG_INFO_OP just to ensure that # function debug info tracing is nominally functioning. found_op = False - for key in debug_info.traces.keys(): + for key in debug_info.name_to_trace_id.keys(): if key.startswith("DEBUG_INFO_OP@"): found_op = True break - self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace") + self.assertTrue( + found_op, "Did not find DEBUG_INFO_OP in trace: %s" % debug_info + ) def test_save_debug_info_disabled(self): root = autotrackable.AutoTrackable() diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index cf34e0af2239f5..8b4ad8155aac3e 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -366,8 +366,6 @@ py_strict_library( ], ) -# Note: this is a heavyweight library specialized for TensorFlow graphs. Do not use for -# other purposes. 
py_strict_library( name = "tf_stack", srcs = ["tf_stack.py"], @@ -376,6 +374,7 @@ py_strict_library( visibility = ["//visibility:public"], deps = [ ":_tf_stack", + "//tensorflow/core:protos_all_py", "@six_archive//:six", ], ) @@ -400,6 +399,8 @@ tf_python_pybind_extension( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:status_casters", ] + if_static([ ":stack_trace", ]), diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 25d5cdbc540917..07ebbd038d7bb5 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -24,12 +24,13 @@ limitations under the License. // clang-format off // These headers must be at the top, before including Python.h header // Otherwise, we get C2039 on MSVC due to 'copysign' +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil #include "pybind11/complex.h" // from @pybind11 #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 #include "pybind11/stl_bind.h" // from @pybind11 // clang-format on - #include #include @@ -45,6 +46,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" @@ -181,6 +183,8 @@ class StackTraceWrapper : public AbstractStackTrace { } // namespace PYBIND11_MODULE(_tf_stack, m) { + pybind11::google::ImportStatusModule(); + py::class_(m, "PyBindSourceMap") .def(py::init()) .def("update_to", @@ -217,6 +221,28 @@ PYBIND11_MODULE(_tf_stack, m) { } }); + py::class_(m, "GraphDebugInfoBuilder") + .def(py::init()) + .def( + "AppendGraphDebugInfo", + [](GraphDebugInfoBuilder& self, std::string fn_name, + py::bytes debug_info_str) { + return self.AppendGraphDebugInfoStr(fn_name, debug_info_str); + }, + py::arg("prefix"), py::arg("debug_info")) + .def( + "AccumulateStackTrace", + [](GraphDebugInfoBuilder& self, std::string function, std::string op, + const AbstractStackTrace& trace) { + std::string key = absl::StrCat(op, "@", function); + self.AccumulateStackTrace( + std::make_shared(trace.ToFrames()), key); + }, + py::arg("function"), py::arg("op"), py::arg("trace")) + .def("Build", [](GraphDebugInfoBuilder& self) -> py::bytes { + return py::bytes(self.ToGraphDebugInfoStr()); + }); + py::class_(m, "StackFrame") .def_property_readonly( "filename", @@ -327,6 +353,11 @@ PYBIND11_MODULE(_tf_stack, m) { }, py::arg("source_map"), py::arg("file_set"), py::arg("stacklevel") = 1, py::return_value_policy::take_ownership); + + m.def( + "LoadTracesFromDebugInfo", + [](py::bytes data) { return LoadTracesFromDebugInfoStr(data); }, + py::arg("debug_info_proto")); } } // namespace tensorflow diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py index dc14576fb1a2b8..ae68e39e265ebf 100644 --- a/tensorflow/python/util/tf_stack.py +++ b/tensorflow/python/util/tf_stack.py @@ -18,6 +18,7 @@ import inspect import threading +from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.python.util import _tf_stack # Generally such lookups should be done using `threading.local()`. 
See @@ -165,5 +166,22 @@ def extract_stack(stacklevel=1): ) +def LoadTracesFromDebugInfo(debug_info): + return _tf_stack.LoadTracesFromDebugInfo(debug_info.SerializeToString()) + + +class GraphDebugInfoBuilder(_tf_stack.GraphDebugInfoBuilder): + + def AppendGraphDebugInfo(self, fn_name, fn_debug_info): + debug_info_str = fn_debug_info.SerializeToString() + super().AppendGraphDebugInfo(fn_name, debug_info_str) + + def Build(self): + debug_info_str = super().Build() + debug_info = graph_debug_info_pb2.GraphDebugInfo() + debug_info.ParseFromString(debug_info_str) + return debug_info + + StackSummary = _tf_stack.StackTrace FrameSummary = _tf_stack.StackFrame diff --git a/tensorflow/python/util/tf_stack_test.py b/tensorflow/python/util/tf_stack_test.py index b64bc3bf5e0b38..8b607acb30aab5 100644 --- a/tensorflow/python/util/tf_stack_test.py +++ b/tensorflow/python/util/tf_stack_test.py @@ -95,6 +95,25 @@ def func(n): str(trace[0]), 'File "filename", line 42, in function_name' ) + def testStackTraceBuilder(self): + stack1 = tf_stack.extract_stack() + stack2 = tf_stack.extract_stack() + stack3 = tf_stack.extract_stack() + + builder = tf_stack.GraphDebugInfoBuilder() + builder.AccumulateStackTrace('func1', 'node1', stack1) + builder.AccumulateStackTrace('func2', 'node2', stack2) + builder.AccumulateStackTrace('func3', 'node3', stack3) + debug_info = builder.Build() + + trace_map = tf_stack.LoadTracesFromDebugInfo(debug_info) + self.assertSameElements( + trace_map.keys(), ['node1@func1', 'node2@func2', 'node3@func3'] + ) + + for trace in trace_map.values(): + self.assertRegex(repr(trace), 'tf_stack_test.py', trace) + if __name__ == "__main__": test.main() diff --git a/tensorflow/tsl/platform/stack_frame.h b/tensorflow/tsl/platform/stack_frame.h index 780cae0940130c..a52a8d53fca60d 100644 --- a/tensorflow/tsl/platform/stack_frame.h +++ b/tensorflow/tsl/platform/stack_frame.h @@ -39,6 +39,12 @@ struct StackFrame { } bool operator!=(const StackFrame& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const StackFrame& frame) { + return h.combine(std::move(h), frame.file_name, frame.line_number, + frame.function_name); + } }; } // namespace tsl From d65b9f2d41bd6668ada73db0889c13a2d083dd68 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Thu, 27 Jul 2023 12:08:20 -0700 Subject: [PATCH 258/410] Guard ARM64 CPU info changes when on Apple or OpenBSD. PiperOrigin-RevId: 551603020 --- tensorflow/tsl/platform/cpu_info.cc | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index 934e4af4563e70..f716fb6b3c5aca 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -22,13 +22,13 @@ limitations under the License. #if defined(PLATFORM_IS_X86) #include // NOLINT #endif -#if defined(PLATFORM_IS_ARM64) +#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) #include #ifndef HWCAP_CPUID #define HWCAP_CPUID (1 << 11) #endif #include -#endif +#endif // PLATFORM_IS_ARM64 && !__APPLE__ && !__OpenBSD__ // SIMD extension querying is only available on x86. 
#ifdef PLATFORM_IS_X86 @@ -352,7 +352,7 @@ void InitCPUIDInfo() { #endif // PLATFORM_IS_X86 -#ifdef PLATFORM_IS_ARM64 +#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) class CPUIDInfo; void InitCPUIDInfo(); @@ -377,7 +377,7 @@ class CPUIDInfo { } int present_cpu = -1; -#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) && !defined(__OpenBSD__) +#ifndef PLATFORM_WINDOWS std::ifstream CPUspresent; CPUspresent.open("/sys/devices/system/cpu/present", std::ios::in); if (CPUspresent.is_open()) { @@ -398,13 +398,13 @@ class CPUIDInfo { present_cpu = std::stoi(line); } } -#endif +#endif // !PLATFORM_WINDOWS if (present_cpu == -1) { return; } -#if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) && !defined(__OpenBSD__) +#ifndef PLATFORM_WINDOWS std::stringstream str; str << "/sys/devices/system/cpu/cpu" << present_cpu << "/regs/identification/midr_el1"; @@ -420,7 +420,7 @@ class CPUIDInfo { cpuid->cpunum_ = (midr_el1 >> 4) & 0xFFF; } } -#endif +#endif // !PLATFORM_WINDOWS } int implementer() const { return implementer_; } @@ -438,7 +438,8 @@ void InitCPUIDInfo() { absl::call_once(cpuid_once_flag, CPUIDInfo::Initialize); } -#endif +#endif // PLATFORM_IS_ARM64 && !__APPLE__ && !__OpenBSD__ + } // namespace bool TestCPUFeature(CPUFeature feature) { @@ -462,7 +463,7 @@ int CPUFamily() { #ifdef PLATFORM_IS_X86 InitCPUIDInfo(); return cpuid->family(); -#elif defined(PLATFORM_IS_ARM64) +#elif defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) InitCPUIDInfo(); return cpuid->implementer(); #else @@ -474,7 +475,7 @@ int CPUModelNum() { #ifdef PLATFORM_IS_X86 InitCPUIDInfo(); return cpuid->model_num(); -#elif defined(PLATFORM_IS_ARM64) +#elif defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) InitCPUIDInfo(); return cpuid->cpunum(); #else From c73eaf43004edbd589ab175f10ed2480be25038a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 12:26:16 -0700 Subject: [PATCH 259/410] Removing jit backend from tfrt. 
PiperOrigin-RevId: 551607952 --- tensorflow/compiler/mlir/tfrt/BUILD | 1 - tensorflow/core/runtime_fallback/BUILD | 1 - 2 files changed, 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 0de3fa7df2efcd..1e2f87545928e8 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -554,7 +554,6 @@ tf_cc_binary( "@tf_runtime//:dtype", "@tf_runtime//:simple_tracing_sink", "@tf_runtime//tools:bef_executor_expensive_kernels", - "@tf_runtime//tools:bef_executor_jit_kernels", "@tf_runtime//tools:bef_executor_lib", "@tf_runtime//tools:bef_executor_lightweight_kernels", ], diff --git a/tensorflow/core/runtime_fallback/BUILD b/tensorflow/core/runtime_fallback/BUILD index af9b5526e130e8..c4b3c5638efa0f 100644 --- a/tensorflow/core/runtime_fallback/BUILD +++ b/tensorflow/core/runtime_fallback/BUILD @@ -45,7 +45,6 @@ tf_cc_binary( # copybara:uncomment "@tf_runtime//backends/cpu:image_alwayslink", # copybara:uncomment "@tf_runtime//backends/cpu:proto_alwayslink", "@tf_runtime//backends/cpu:test_ops_alwayslink", - # copybara:uncomment "@tf_runtime//backends/jitrt:jitrt_corert_kernels_alwayslink", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs From e57d550a06b891cad890b2a963ef33b37c4a23ec Mon Sep 17 00:00:00 2001 From: David Silverstone Date: Thu, 27 Jul 2023 12:32:34 -0700 Subject: [PATCH 260/410] Provide interface to enable XLA outside compilation for Cloud TPU VMs. The TPU->CPU transfer and CPU computation is enabled. PiperOrigin-RevId: 551609509 --- .../stream_executor/tpu/c_api_conversions.h | 1 - .../xla/stream_executor/tpu/c_api_decl.h | 2 + .../stream_executor/tpu/tpu_op_executable.cc | 13 +- .../stream_executor/tpu/tpu_op_executable.h | 8 + .../xla/stream_executor/tpu/tpu_ops_c_api.h | 1 + .../next_pluggable_device/c/BUILD | 59 ++ .../c/outside_compilation_params.h | 39 ++ .../c/tf_rendezvous_c_api.h | 138 +++++ .../c/tf_rendezvous_c_api_conversions.cc | 547 ++++++++++++++++++ .../c/tf_rendezvous_c_api_conversions.h | 93 +++ .../c/tf_rendezvous_c_api_conversions_test.cc | 48 ++ .../c/tf_rendezvous_c_api_defn.h | 36 ++ tensorflow/core/tpu/BUILD | 6 +- tensorflow/core/tpu/tpu_execute.cc | 37 +- 14 files changed, 1018 insertions(+), 10 deletions(-) create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions_test.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h index 03abaaa533dea0..3af72cd3266a12 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/device_memory.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/host_command_handler.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h" // APIs for converting between internal and external versions of // XLA/StreamExecutor data structures. diff --git a/tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h b/tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h index 3c5f3f93bf6330..be7d9d8775f943 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h @@ -360,6 +360,8 @@ typedef struct SE_TpuHostCommandHandler { void* context; } SE_TpuHostCommandHandler; +typedef struct SE_OutsideCompilationParams SE_OutsideCompilationParams; + #ifdef __cplusplus } #endif diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc index 87e3cd51c23e83..0465685921dd9c 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.cc @@ -37,7 +37,17 @@ TpuOpExecutable::TpuOpExecutable(const XLA_TpuProgram* core_program, HostCommandHandler host_command_handler) : TpuExecutableInterface(std::move(hlo_module)), core_program_(core_program), - host_command_handler_(std::move(host_command_handler)) {} + host_command_handler_(std::move(host_command_handler)), + outside_compilation_params_(nullptr) {} + +TpuOpExecutable::TpuOpExecutable( + const XLA_TpuProgram* core_program, + std::unique_ptr hlo_module, + SE_OutsideCompilationParams* outside_compilation_params) + : TpuExecutableInterface(std::move(hlo_module)), + core_program_(core_program), + host_command_handler_(nullptr), + outside_compilation_params_(outside_compilation_params) {} xla::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( const xla::ServiceExecutableRunOptions& run_options, @@ -100,6 +110,7 @@ xla::Status TpuOpExecutable::LoadProgramAndEnqueueToStream( params.device_assignment = &c_dev_assign; params.stream = stream; params.host_command_handler = ApiConverter::ToC(host_command_handler_); + params.outside_compilation_params = outside_compilation_params_; params.status = status.c_status; stream_executor::tpu::OpsApiFn() diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h index eed87fb09ba606..4ce192d939e149 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/host_command_handler.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executable_interface.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" @@ -42,6 +43,11 @@ class TpuOpExecutable : public xla::TpuExecutableInterface { std::unique_ptr hlo_module, HostCommandHandler host_command_handler = nullptr); + explicit TpuOpExecutable( + const XLA_TpuProgram* core_program, + std::unique_ptr hlo_module, + SE_OutsideCompilationParams* outside_compilation_params); + ~TpuOpExecutable() override = default; const XLA_TpuProgram* core_program() const { return core_program_; } @@ -61,6 +67,8 @@ class TpuOpExecutable : public xla::TpuExecutableInterface { const HostCommandHandler host_command_handler_; + SE_OutsideCompilationParams* outside_compilation_params_; + TF_DISALLOW_COPY_AND_ASSIGN(TpuOpExecutable); }; diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h b/tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h index a8215981e0886b..8ecd4a69ff4f57 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h @@ -235,6 +235,7 @@ typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { XLA_DeviceAssignment* device_assignment; SE_Stream* stream; SE_TpuHostCommandHandler* host_command_handler; + SE_OutsideCompilationParams* outside_compilation_params; TF_Status* status; // out } TpuExecutable_LoadProgramAndEnqueueToStream_Params; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD index ffcbff253edd76..de20c0e8d1e265 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD @@ -57,3 +57,62 @@ tf_cc_shared_object( "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs", ], ) + +cc_library( + name = "outside_compilation_params", + hdrs = ["outside_compilation_params.h"], + visibility = ["//visibility:public"], + deps = [ + ":tf_rendezvous_c_api", + "//tensorflow/compiler/xla/stream_executor/tpu:c_api_decl", + ], +) + +cc_library( + name = "tf_rendezvous_c_api", + hdrs = ["tf_rendezvous_c_api.h"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/c:tf_status_headers"], +) + +cc_library( + name = "tf_rendezvous_c_api_defn", + hdrs = ["tf_rendezvous_c_api_defn.h"], + deps = [ + ":tf_rendezvous_c_api", + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tf_rendezvous_c_api_conversions", + srcs = ["tf_rendezvous_c_api_conversions.cc"], + hdrs = ["tf_rendezvous_c_api_conversions.h"], + visibility = ["//visibility:public"], + deps = [ + ":outside_compilation_params", + ":tf_rendezvous_c_api", + ":tf_rendezvous_c_api_defn", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/xla/stream_executor/tpu:proto_helper", + "//tensorflow/core:framework_internal", + "//tensorflow/tsl/framework:allocator", + ], +) + +tf_cc_test( + name = "tf_rendezvous_c_api_conversions_test", + srcs = ["tf_rendezvous_c_api_conversions_test.cc"], + tags = [ + "no_mac", + "no_windows", + ], + deps = [ + ":tf_rendezvous_c_api", + ":tf_rendezvous_c_api_conversions", 
+ "//tensorflow/c:tf_status", + "//tensorflow/tsl/framework:allocator", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h b/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h new file mode 100644 index 00000000000000..204c560abb16b6 --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ + +#include + +#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct SE_OutsideCompilationParams { + char* device_name; + char* rendezvous_key; + TF_RendezvousThunk* rendezvous; + TpuSerializedProto host_transfers; +}; + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h new file mode 100644 index 00000000000000..fc1155f113a03c --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h @@ -0,0 +1,138 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ + +#include + +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_DeviceContext TF_DeviceContext; + +typedef struct TFDevice_AllocatorAttributes { + uint32_t value; + int32_t scope_id; +} TFDevice_AllocatorAttributes; + +typedef struct TF_CancellationManager TF_CancellationManager; + +typedef struct TF_TensorWrapper TF_TensorWrapper; + +typedef struct TF_RendezvousArgsStruct { + TF_DeviceContext* device_context; + TFDevice_AllocatorAttributes alloc_attrs; + TF_CancellationManager* cancellation_manager; +} TF_RendezvousArgsStruct; + +typedef struct TF_DeviceUtilsParsedName { + char* job_str; + uint32_t job_str_size; + bool has_replica; + int replica; + bool has_task; + int task; + char* type_str; + uint32_t type_str_size; + bool has_id; + int id; +} TF_DeviceUtilsParsedName; + +typedef struct TF_RendezvousParsedKey { + char* src_device_str; + uint32_t src_device_str_size; + TF_DeviceUtilsParsedName src_parsed_name; + uint64_t src_incarnation; + + char* dst_device_str; + uint32_t dst_device_str_size; + TF_DeviceUtilsParsedName dst_parsed_name; + + char* edge_name; + uint32_t edge_name_size; +} TF_RendezvousParsedKey; + +typedef struct TF_RendezvousSend_Params { + const TF_RendezvousParsedKey* key; + const TF_RendezvousArgsStruct* args; + const TF_TensorWrapper* tensor; + bool is_dead; + + TF_Status* status; // out +} TF_RendezvousSend_Params; + +typedef void (*TF_RendezvousSend_Function)(void*, TF_RendezvousSend_Params*); + +typedef struct TF_RendezvousSenderImpl { + void* context; + TF_RendezvousSend_Function send_func; +} TF_RendezvousSenderImpl; + +typedef struct TF_RendezvousDoneCallback_Params { + void* context; + const TF_Status* status; + const TF_RendezvousArgsStruct* sender_args; + const TF_RendezvousArgsStruct* recver_args; + const TF_TensorWrapper* tensor; + bool is_dead; +} TF_RendezvousDoneCallback_Params; + +typedef void (*TF_RendezvousDoneCallback_Function)( + void*, TF_RendezvousDoneCallback_Params*); + +typedef struct TF_RendezvousDoneCallbackImpl { + void* context; + TF_RendezvousDoneCallback_Function callback; +} TF_RendezvousDoneCallbackImpl; + +typedef struct TF_RendezvousAsyncRecv_Params { + void* context; + const TF_RendezvousParsedKey* key; + const TF_RendezvousArgsStruct* args; + TF_RendezvousDoneCallbackImpl on_done; +} TF_RendezvousAsyncRecv_Params; + +typedef void (*TF_RendezvousAsyncRecv_Function)(void*, + TF_RendezvousAsyncRecv_Params*); + +typedef struct TF_RendezvousAsyncRecverImpl { + void* context; + TF_RendezvousAsyncRecv_Function async_recv_func; +} TF_RendezvousAsyncRecverImpl; + +typedef void (*TF_RendezvousStartAbort_Function)(void* context, + const TF_Status*); + +typedef struct TF_RendezvousStartAbortImpl { + void* context; + TF_RendezvousStartAbort_Function start_abort_func; +} TF_RendezvousStartAbortImpl; + +typedef struct TF_RendezvousThunk { + void* context; // not owned + TF_RendezvousSenderImpl send; + TF_RendezvousAsyncRecverImpl async_recv; + TF_RendezvousStartAbortImpl start_abort; +} TF_RendezvousThunk; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.cc 
b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.cc new file mode 100644 index 00000000000000..754d5952946dd5 --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.cc @@ -0,0 +1,547 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h" + +#include +#include +#include + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h" +#include "tensorflow/core/framework/rendezvous.h" + +using TF_StatusCallback = std::function; + +namespace tensorflow { + +TFDevice_AllocatorAttributes ToC(const tsl::AllocatorAttributes& attributes) { + TFDevice_AllocatorAttributes c_attributes; + c_attributes.value = attributes.value; + c_attributes.scope_id = attributes.scope_id; + return c_attributes; +} + +tsl::AllocatorAttributes FromC( + const TFDevice_AllocatorAttributes& c_attributes) { + tsl::AllocatorAttributes attributes; + attributes.value = c_attributes.value; + attributes.scope_id = c_attributes.scope_id; + return attributes; +} + +void Destroy(TFDevice_AllocatorAttributes* c_attributes) {} + +TF_RendezvousArgsStruct ToC(const RendezvousInterface::Args& args) { + TF_RendezvousArgsStruct c_args; + c_args.device_context = new TF_DeviceContext(); + c_args.cancellation_manager = new TF_CancellationManager(); + c_args.device_context->device_context = args.device_context; + c_args.alloc_attrs = ToC(args.alloc_attrs); + c_args.cancellation_manager->cancellation_manager = args.cancellation_manager; + return c_args; +} + +RendezvousInterface::Args FromC(const TF_RendezvousArgsStruct& c_args) { + RendezvousInterface::Args args; + args.alloc_attrs = FromC(c_args.alloc_attrs); + args.device_context = (c_args.device_context == nullptr) + ? nullptr + : c_args.device_context->device_context; + args.cancellation_manager = + (c_args.cancellation_manager == nullptr) + ? 
nullptr + : c_args.cancellation_manager->cancellation_manager; + return args; +} + +void Destroy(TF_RendezvousArgsStruct* c_args) { + if (c_args->device_context != nullptr) { + delete c_args->device_context; + } + if (c_args->cancellation_manager != nullptr) { + delete c_args->cancellation_manager; + } +} + +TF_DeviceUtilsParsedName ToC(const DeviceNameUtils::ParsedName& name) { + TF_DeviceUtilsParsedName c_name; + if (name.has_job) { + c_name.job_str = new char[name.job.size() + 1]; + c_name.job_str_size = name.job.size(); + std::strncpy(c_name.job_str, name.job.data(), name.job.size()); + } else { + c_name.job_str = nullptr; + c_name.job_str_size = 0; + std::strncpy(c_name.type_str, name.type.data(), name.type.size()); + } + if (name.has_type) { + c_name.type_str = new char[name.type.size() + 1]; + c_name.type_str_size = name.type.size(); + } else { + c_name.type_str = nullptr; + c_name.type_str_size = 0; + } + c_name.has_replica = name.has_replica; + c_name.replica = name.replica; + c_name.has_task = name.has_task; + c_name.task = name.task; + c_name.has_id = name.has_id; + c_name.id = name.id; + return c_name; +} + +DeviceNameUtils::ParsedName FromC(const TF_DeviceUtilsParsedName& c_name) { + DeviceNameUtils::ParsedName name; + if (c_name.job_str != nullptr) { + name.job = absl::string_view(c_name.job_str, c_name.job_str_size); + name.has_job = true; + } else { + name.has_job = false; + } + if (c_name.type_str != nullptr) { + name.type = absl::string_view(c_name.type_str, c_name.type_str_size); + name.has_type = true; + } else { + name.has_type = false; + } + name.has_replica = c_name.has_replica; + name.replica = c_name.replica; + name.has_task = c_name.has_task; + name.task = c_name.task; + name.has_id = c_name.has_id; + name.id = c_name.id; + return name; +} + +void Destroy(TF_DeviceUtilsParsedName* c_name) { + if (c_name->job_str != nullptr) { + delete[] c_name->job_str; + } + if (c_name->type_str != nullptr) { + delete[] c_name->type_str; + } +} + +TF_RendezvousParsedKey ToC(const RendezvousInterface::ParsedKey& key) { + TF_RendezvousParsedKey c_key; + c_key.src_device_str_size = key.src_device.size(); + c_key.src_device_str = new char[c_key.src_device_str_size + 1]; + std::strncpy(c_key.src_device_str, key.src_device.data(), + key.src_device.size()); + c_key.src_parsed_name = ToC(key.src); + c_key.src_incarnation = key.src_incarnation; + + c_key.dst_device_str_size = key.dst_device.size(); + c_key.dst_device_str = new char[c_key.dst_device_str_size + 1]; + c_key.dst_device_str_size = key.dst_device.size(); + std::strncpy(c_key.dst_device_str, key.dst_device.data(), + key.dst_device.size()); + c_key.dst_parsed_name = ToC(key.dst); + + c_key.edge_name = new char[key.edge_name.size() + 1]; + c_key.edge_name_size = key.edge_name.size(); + std::strncpy(c_key.edge_name, key.edge_name.data(), key.edge_name.size()); + + return c_key; +} + +RendezvousInterface::ParsedKey FromC(const TF_RendezvousParsedKey& c_key) { + RendezvousInterface::ParsedKey key; + key.src_device = + absl::string_view(c_key.src_device_str, c_key.src_device_str_size); + key.src = FromC(c_key.src_parsed_name); + key.src_incarnation = c_key.src_incarnation; + + key.dst_device = + absl::string_view(c_key.dst_device_str, c_key.dst_device_str_size); + key.dst = FromC(c_key.dst_parsed_name); + + key.edge_name = absl::string_view(c_key.edge_name, c_key.edge_name_size); + + return key; +} + +void Destroy(TF_RendezvousParsedKey* c_key) { + delete[] c_key->src_device_str; + delete[] c_key->dst_device_str; + delete[] 
c_key->edge_name; + Destroy(&c_key->src_parsed_name); + Destroy(&c_key->dst_parsed_name); +} + +namespace { + +using SendParamDeleter = std::function; +using RecvParamDeleter = std::function; +using DoneCallbackParamDeleter = + std::function; + +using SendParamPtr = + std::unique_ptr; +using RecvParamPtr = + std::unique_ptr; +using DoneCallbackParamPtr = + std::unique_ptr; + +SendParamDeleter MakeSendParamDeleter(); +SendParamPtr SendParamsToC(const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, + const Tensor& val, bool is_dead); + +RecvParamDeleter MakeRecvParamDeleter(); +RecvParamPtr RecvParamsToC(const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, + RendezvousInterface::DoneCallback on_done); + +DoneCallbackParamDeleter MakeDoneCallbackParamDeleter(); +DoneCallbackParamPtr DoneCallbackParamsToC( + const Status& status, const RendezvousInterface::Args& sender_args, + const RendezvousInterface::Args& recver_args, const Tensor& tensor, + bool is_dead); + +// Use in `TF_RendezvousThunk ToC(tensorflow::RendezvousInterface* rendezvous)` +TF_RendezvousSenderImpl BindSendFunction(RendezvousInterface* rendezvous); + +// Use in `TF_RendezvousThunk ToC(tensorflow::RendezvousInterface* rendezvous)` +TF_RendezvousAsyncRecverImpl BindAsyncRecvFunction( + RendezvousInterface* rendezvous); + +TF_RendezvousStartAbortImpl BindStartAborter(RendezvousInterface* rendezvous); + +void RendezvousCallbackThunk(void* context, + TF_RendezvousDoneCallback_Params* params) { + using CallbackType = std::function; + auto* callback = static_cast(context); + (*callback)(params); +} + +} // namespace + +TF_RendezvousDoneCallbackImpl ToC( + const RendezvousInterface::DoneCallback& on_done) { + TF_RendezvousDoneCallbackImpl done_func; + using CallbackType = std::function; + auto c_callback = new CallbackType( + [on_done](TF_RendezvousDoneCallback_Params* params) -> void { + Status status = tsl::StatusFromTF_Status(params->status); + auto sender_args = FromC(*params->sender_args); + auto recver_args = FromC(*params->recver_args); + const Tensor& tensor = params->tensor->tensor; + on_done(status, sender_args, recver_args, tensor, params->is_dead); + }); + done_func.context = static_cast(c_callback); + done_func.callback = RendezvousCallbackThunk; + return done_func; +} + +RendezvousInterface::DoneCallback FromC( + const TF_RendezvousDoneCallbackImpl& c_on_done) { + if (c_on_done.context == nullptr) { + return nullptr; + } + TF_RendezvousDoneCallback_Function callback = c_on_done.callback; + void* context = c_on_done.context; + auto cpp_callback = [callback, context](const Status& status, + RendezvousInterface::Args sender_args, + RendezvousInterface::Args recver_args, + const Tensor& tensor, + const bool is_dead) -> void { + DoneCallbackParamPtr params = DoneCallbackParamsToC( + status, sender_args, recver_args, tensor, is_dead); + callback(context, params.get()); + }; + return cpp_callback; +} + +void Destroy(TF_RendezvousDoneCallbackImpl* c_on_done) { + if (c_on_done == nullptr) { + return; + } + if (c_on_done->context != nullptr) { + auto runner = + static_cast*>( + c_on_done->context); + delete runner; + } +} + +TF_RendezvousThunk* ToC(RendezvousInterface* rendezvous) { + TF_RendezvousThunk* thunk = new TF_RendezvousThunk(); + thunk->context = rendezvous; + + thunk->send = BindSendFunction(rendezvous); + thunk->async_recv = BindAsyncRecvFunction(rendezvous); + thunk->start_abort = BindStartAborter(rendezvous); + + return thunk; +} + +std::unique_ptr FromC( 
+ const TF_RendezvousThunk* thunk) { + return std::make_unique(thunk); +} + +void Destroy(TF_RendezvousThunk* thunk) { + if (thunk == nullptr) { + return; + } + Destroy(&thunk->send); + Destroy(&thunk->async_recv); + Destroy(&thunk->start_abort); + delete thunk; +} + +namespace { + +SendParamDeleter MakeSendParamDeleter() { + return [](TF_RendezvousSend_Params* params) { + if (params == nullptr) { + return; + } + TF_RendezvousParsedKey* key = + const_cast(params->key); + TF_RendezvousArgsStruct* args = + const_cast(params->args); + Destroy(key); + Destroy(args); + delete params->key; + delete params->args; + delete params->tensor; + TF_DeleteStatus(params->status); + delete params; + }; +} + +SendParamPtr SendParamsToC(const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, + const Tensor& val, const bool is_dead) { + TF_RendezvousSend_Params* params = new TF_RendezvousSend_Params(); + params->key = new TF_RendezvousParsedKey(ToC(key)); + params->args = new TF_RendezvousArgsStruct(ToC(args)); + params->tensor = new TF_TensorWrapper({val}); + params->is_dead = is_dead; + params->status = TF_NewStatus(); + return SendParamPtr(params, MakeSendParamDeleter()); +} + +RecvParamDeleter MakeRecvParamDeleter() { + return [](TF_RendezvousAsyncRecv_Params* params) { + if (params == nullptr) { + return; + } + TF_RendezvousParsedKey* key = + const_cast(params->key); + TF_RendezvousArgsStruct* args = + const_cast(params->args); + Destroy(key); + Destroy(args); + Destroy(¶ms->on_done); + delete params->key; + delete params->args; + delete params; + }; +} + +RecvParamPtr RecvParamsToC(const RendezvousInterface::ParsedKey& key, + const RendezvousInterface::Args& args, + RendezvousInterface::DoneCallback on_done) { + TF_RendezvousAsyncRecv_Params* params = new TF_RendezvousAsyncRecv_Params(); + params->key = new TF_RendezvousParsedKey(ToC(key)); + params->args = new TF_RendezvousArgsStruct(ToC(args)); + params->on_done = ToC(on_done); + return RecvParamPtr(params, MakeRecvParamDeleter()); +} + +DoneCallbackParamDeleter MakeDoneCallbackParamDeleter() { + return [](TF_RendezvousDoneCallback_Params* params) { + if (params == nullptr) { + return; + } + TF_RendezvousArgsStruct* sender_args = + const_cast(params->sender_args); + TF_RendezvousArgsStruct* recver_args = + const_cast(params->recver_args); + Destroy(sender_args); + Destroy(recver_args); + TF_Status* status = const_cast(params->status); + TF_DeleteStatus(status); + delete params->sender_args; + delete params->recver_args; + delete params->tensor; + delete params; + }; +} + +DoneCallbackParamPtr DoneCallbackParamsToC( + const Status& status, const RendezvousInterface::Args& sender_args, + const RendezvousInterface::Args& recver_args, const Tensor& tensor, + const bool is_dead) { + TF_RendezvousDoneCallback_Params* params = + new TF_RendezvousDoneCallback_Params; + TF_Status* c_status = TF_NewStatus(); + tsl::Set_TF_Status_from_Status(c_status, status); + params->status = c_status; + params->sender_args = new TF_RendezvousArgsStruct(ToC(sender_args)); + params->recver_args = new TF_RendezvousArgsStruct(ToC(recver_args)); + params->tensor = new TF_TensorWrapper({tensor}); + params->is_dead = is_dead; + return DoneCallbackParamPtr(params, MakeDoneCallbackParamDeleter()); +} + +void SendFunctionThunk(void* context, TF_RendezvousSend_Params* params) { + using SendFunction = std::function; + auto* send_func = static_cast(context); + (*send_func)(params); +} + +// Use in `TF_RendezvousThunk ToC(tensorflow::RendezvousInterface* 
rendezvous)` +TF_RendezvousSenderImpl BindSendFunction(RendezvousInterface* rendezvous) { + TF_RendezvousSenderImpl send_func; + using SendFunction = std::function; + auto sender = + new SendFunction([rendezvous](TF_RendezvousSend_Params* params) -> void { + RendezvousInterface::ParsedKey key = FromC(*params->key); + RendezvousInterface::Args args = FromC(*params->args); + const Tensor& tensor = params->tensor->tensor; + bool is_dead = params->is_dead; + tsl::Set_TF_Status_from_Status( + params->status, rendezvous->Send(key, args, tensor, is_dead)); + }); + send_func.context = static_cast(sender); + send_func.send_func = SendFunctionThunk; + return send_func; +} + +void RecvFunctionThunk(void* context, TF_RendezvousAsyncRecv_Params* params) { + using RecvFunction = std::function; + auto* recv_func = static_cast(context); + (*recv_func)(params); +} + +// Use in `TF_RendezvousThunk ToC(tensorflow::RendezvousInterface* rendezvous)` +TF_RendezvousAsyncRecverImpl BindAsyncRecvFunction( + RendezvousInterface* rendezvous) { + TF_RendezvousAsyncRecverImpl recv_func; + using RecvFunction = std::function; + auto recver = new RecvFunction( + [rendezvous](TF_RendezvousAsyncRecv_Params* params) -> void { + RendezvousInterface::ParsedKey key = FromC(*params->key); + RendezvousInterface::Args args = FromC(*params->args); + RendezvousInterface::DoneCallback on_done = FromC(params->on_done); + rendezvous->RecvAsync(key, args, on_done); + }); + recv_func.context = static_cast(recver); + recv_func.async_recv_func = RecvFunctionThunk; + return recv_func; +} + +void StartAbortFunctionThunk(void* context, const TF_Status* status) { + auto* callback = static_cast(context); + (*callback)(status); +} + +// Use in `TF_RendezvousThunk ToC(tensorflow::RendezvousInterface* rendezvous)` +TF_RendezvousStartAbortImpl BindStartAborter(RendezvousInterface* rendezvous) { + TF_RendezvousStartAbortImpl start_abort; + auto aborter = + new TF_StatusCallback([rendezvous](const TF_Status* status) -> void { + rendezvous->StartAbort(tsl::StatusFromTF_Status(status)); + }); + start_abort.context = static_cast(aborter); + start_abort.start_abort_func = StartAbortFunctionThunk; + return start_abort; +} + +} // namespace + +void Destroy(TF_RendezvousSenderImpl* send_func) { + if (send_func == nullptr) { + return; + } + if (send_func->context != nullptr) { + auto runner = static_cast*>( + send_func->context); + delete runner; + } +} + +void Destroy(TF_RendezvousAsyncRecverImpl* recv_func) { + if (recv_func == nullptr) { + return; + } + if (recv_func->context != nullptr) { + auto runner = + static_cast*>( + recv_func->context); + delete runner; + } +} + +void Destroy(TF_RendezvousStartAbortImpl* start_abort_func) { + if (start_abort_func == nullptr) { + return; + } + if (start_abort_func->context != nullptr) { + auto runner = static_cast(start_abort_func->context); + delete runner; + } +} + +namespace c_api { + +Status TfCThunkRendezvous::Send(const ParsedKey& key, const Args& args, + const Tensor& val, const bool is_dead) { + SendParamPtr params = SendParamsToC(key, args, val, is_dead); + const TF_RendezvousSenderImpl& sender = thunk_->send; + sender.send_func(sender.context, params.get()); + return tsl::StatusFromTF_Status(params->status); +} + +void TfCThunkRendezvous::RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) { + RecvParamPtr params = RecvParamsToC(key, args, done); + const TF_RendezvousAsyncRecverImpl& async_recv = thunk_->async_recv; + async_recv.async_recv_func(async_recv.context, params.get()); +} + 
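Every conversion in this file follows the same bridging idiom: the C++ callable is heap-allocated, hidden behind a `void*` context, and invoked through a plain trampoline function so the C struct carries nothing but POD members. A minimal, self-contained sketch of that idiom follows (the Example* names are illustrative only, not part of the TF C API):

#include <functional>
#include <iostream>

// A C-style struct holding only a context pointer and a function pointer,
// mirroring the TF_RendezvousSenderImpl binding above.
struct ExampleSenderImpl {
  void* context;
  void (*send_func)(void* context, int payload);
};

// Trampoline: recovers the C++ callable from `context` and forwards the call
// (compare SendFunctionThunk).
void ExampleSendThunk(void* context, int payload) {
  auto* fn = static_cast<std::function<void(int)>*>(context);
  (*fn)(payload);
}

// Bind: heap-allocates the callable and wraps it (compare BindSendFunction).
ExampleSenderImpl BindExampleSender() {
  auto* fn = new std::function<void(int)>(
      [](int payload) { std::cout << "send " << payload << "\n"; });
  return ExampleSenderImpl{fn, ExampleSendThunk};
}

// Destroy: the context owns the callable, so it is deleted exactly once,
// through the exact type it was bound with (compare the Destroy overloads
// above).
void DestroyExampleSender(ExampleSenderImpl* impl) {
  delete static_cast<std::function<void(int)>*>(impl->context);
  impl->context = nullptr;
}

int main() {
  ExampleSenderImpl sender = BindExampleSender();
  sender.send_func(sender.context, 42);  // the call as it crosses the C boundary
  DestroyExampleSender(&sender);
  return 0;
}

The same ownership rule is why each Destroy overload above casts the context back to the concrete std::function type before deleting it: the side that created the context is the only side that knows its real type.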
+void TfCThunkRendezvous::StartAbort(const Status& status) { + std::unique_ptr> c_status( + TF_NewStatus(), &TF_DeleteStatus); + tsl::Set_TF_Status_from_Status(c_status.get(), status); + const TF_RendezvousStartAbortImpl& start_abort = thunk_->start_abort; + start_abort.start_abort_func(start_abort.context, c_status.get()); +} + +} // namespace c_api + +void DestroyOCParams(SE_OutsideCompilationParams* params) { + if (params == nullptr) { + return; + } + delete[] params->device_name; + delete[] params->rendezvous_key; + Destroy(params->rendezvous); + if (params->host_transfers.size > 0) { + StreamExecutor_Tpu_FreeSerializedProto(¶ms->host_transfers); + } + delete params; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h new file mode 100644 index 00000000000000..08cff3725b06b8 --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_CONVERSIONS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_CONVERSIONS_H_ + +#include + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/tsl/framework/allocator.h" + +namespace tensorflow { + +namespace c_api { + +class TfCThunkRendezvous final : public ::tensorflow::Rendezvous { + public: + explicit TfCThunkRendezvous(const TF_RendezvousThunk* thunk) + : thunk_(thunk) {} + + ~TfCThunkRendezvous() override = default; + + Status Send(const ParsedKey& key, const Args& args, const Tensor& val, + bool is_dead) override; + + void RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) override; + + void StartAbort(const Status& status) override; + + private: + const TF_RendezvousThunk* thunk_; +}; + +} // namespace c_api + +TFDevice_AllocatorAttributes ToC(const tsl::AllocatorAttributes& attributes); +tsl::AllocatorAttributes FromC( + const TFDevice_AllocatorAttributes& c_attributes); +void Destroy(TFDevice_AllocatorAttributes* c_attributes); + +TF_RendezvousArgsStruct ToC(const tensorflow::RendezvousInterface::Args& args); +tensorflow::RendezvousInterface::Args FromC( + const TF_RendezvousArgsStruct& c_args); +void Destroy(TF_RendezvousArgsStruct* c_args); + +TF_DeviceUtilsParsedName ToC( + const tensorflow::DeviceNameUtils::ParsedName& name); +tensorflow::DeviceNameUtils::ParsedName FromC( + const TF_DeviceUtilsParsedName& c_name); +void Destroy(TF_DeviceUtilsParsedName* c_name); + +TF_RendezvousParsedKey ToC( + 
const tensorflow::RendezvousInterface::ParsedKey& key); +tensorflow::RendezvousInterface::ParsedKey FromC( + const TF_RendezvousParsedKey& c_key); +void Destroy(TF_RendezvousParsedKey* c_key); + +TF_RendezvousDoneCallbackImpl ToC( + const tensorflow::RendezvousInterface::DoneCallback& on_done); +tensorflow::RendezvousInterface::DoneCallback FromC( + const TF_RendezvousDoneCallbackImpl& c_on_done); +void Destroy(TF_RendezvousDoneCallbackImpl* c_on_done); + +TF_RendezvousThunk* ToC(tensorflow::RendezvousInterface* rendezvous); +// `tensorflow::RendezvousInterface` has a protected destructor, so this +// function can't return std::unique_ptr. +std::unique_ptr FromC( + const TF_RendezvousThunk* thunk); +void Destroy(TF_RendezvousThunk* thunk); + +void Destroy(TF_RendezvousSenderImpl* send_func); +void Destroy(TF_RendezvousAsyncRecverImpl* recv_func); +void Destroy(TF_RendezvousStartAbortImpl* start_abort_func); + +void DestroyOCParams(SE_OutsideCompilationParams* params); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_CONVERSIONS_H_ diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions_test.cc new file mode 100644 index 00000000000000..2b20aab285bb93 --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h" + +#include + +#include +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/tsl/framework/allocator.h" + +namespace tensorflow { + +namespace { + +TEST(AllocatorAttributes, ToAndFromC) { + constexpr uint32_t kValue = 0x1234'5678; + constexpr int32_t kScopeId = 1; + + tsl::AllocatorAttributes in_attributes; + in_attributes.value = kValue; + in_attributes.scope_id = kScopeId; + + TFDevice_AllocatorAttributes c_attributes = ToC(in_attributes); + EXPECT_EQ(kValue, c_attributes.value); + EXPECT_EQ(kScopeId, c_attributes.scope_id); + + tsl::AllocatorAttributes out_attributes = FromC(c_attributes); + EXPECT_EQ(kValue, out_attributes.value); + EXPECT_EQ(kScopeId, out_attributes.scope_id); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h new file mode 100644 index 00000000000000..9017f10f1a3437 --- /dev/null +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" + +struct TF_DeviceContext { + tensorflow::DeviceContext* device_context; // not owned +}; + +struct TF_CancellationManager { + tensorflow::CancellationManager* cancellation_manager; // not owned +}; + +struct TF_TensorWrapper { + tensorflow::Tensor tensor; +}; + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index d62dfccb65af53..6344b65ce6c16b 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -259,22 +259,22 @@ cc_library( "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", "//tensorflow/compiler/xla/stream_executor/tpu:c_api_decl", "//tensorflow/compiler/xla/stream_executor/tpu:host_command_handler", + "//tensorflow/compiler/xla/stream_executor/tpu:proto_helper", "//tensorflow/compiler/xla/stream_executor/tpu:status_helper", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_c_api_hdrs", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_node_context", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_op_executable", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_ops_c_api_hdrs", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_platform_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/next_pluggable_device/c:outside_compilation_params", + "//tensorflow/core/common_runtime/next_pluggable_device/c:tf_rendezvous_c_api_conversions", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc", - "//tensorflow/core/tpu/kernels:tpu_execute_op_options", "@com_google_absl//absl/base", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index 7008e7e05c6bf3..115689b4db68e1 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include #include @@ -51,23 +53,23 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_defn.h" #include "tensorflow/compiler/xla/stream_executor/tpu/host_command_handler.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_op_executable.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_conversions.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_execute_op_options.h" #include "tensorflow/tsl/framework/cancellation.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/statusor.h" @@ -406,7 +408,7 @@ void UnregisterCancellation( // 4) StartCancel() in (1) cannot complete until (3) is done. // // Instead, call TryDeregisterCallback. The functional difference is - // TryDeregisterCallback will not block if cancellation is in proress + // TryDeregisterCallback will not block if cancellation is in progress // so makes no guarantees as to the state of any callbacks. // This is not a problem, as our cancellation handler does not rely on // any external state. 
@@ -425,6 +427,27 @@ void UnregisterCancellation( stream->ReturnSubStream(deregister_stream); } +std::unique_ptr> +CreateOcParams(const std::string& rendezvous_key_base, + OpKernelContext* op_kernel_context, + const TPUHostTransferInfoProto& host_transfers) { + std::unique_ptr> + oc_params(new SE_OutsideCompilationParams(), &DestroyOCParams); + const std::string& device_name = op_kernel_context->device()->name(); + oc_params->device_name = new char[device_name.size() + 1]; + std::strncpy(oc_params->device_name, device_name.c_str(), + device_name.size() + 1); + oc_params->rendezvous_key = new char[rendezvous_key_base.size() + 1]; + std::strncpy(oc_params->rendezvous_key, rendezvous_key_base.c_str(), + rendezvous_key_base.size() + 1); + oc_params->rendezvous = ToC(op_kernel_context->rendezvous()); + oc_params->host_transfers = + stream_executor::tpu::SerializeProto(host_transfers); + return oc_params; +} + } // namespace xla::StatusOr TPUExecute( @@ -538,8 +561,12 @@ xla::StatusOr TPUExecute( arguments.push_back(std::move(input)); } + std::unique_ptr> + oc_params = CreateOcParams(rendezvous_key_base, ctx, host_transfers); + auto tpu_executable = std::make_unique( - tpu_program, std::move(module), host_command_handler); + tpu_program, std::move(module), oc_params.get()); const int32_t device_ordinal = node_context->device_ordinal(); CancellationToken token; From ce15c1b8af4dd7d426d8f93524d9ed943c07542b Mon Sep 17 00:00:00 2001 From: Armando Ugalde Velasco Date: Thu, 27 Jul 2023 12:44:46 -0700 Subject: [PATCH 261/410] Use VLOG instead of LOG_EVERY_N_SEC for logging the optimal number of workers This is useful to see the history of the reported optimal number of workers when necessary. PiperOrigin-RevId: 551612633 --- tensorflow/core/data/service/auto_scaler.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/data/service/auto_scaler.cc b/tensorflow/core/data/service/auto_scaler.cc index 4d53bc6625d17b..8b8ad35145c1f3 100644 --- a/tensorflow/core/data/service/auto_scaler.cc +++ b/tensorflow/core/data/service/auto_scaler.cc @@ -216,9 +216,8 @@ tsl::Status MultipleIterationsAutoScaler::UpdateOptimalNumberOfWorkersMetric() "no reported processing and target processing times for at least one " "iteration"); - constexpr float FIVE_MINUTES = 60.0 * 5.0; - LOG_EVERY_N_SEC(INFO, FIVE_MINUTES) << "Estimated optimal number of workers: " - << optimal_number_of_workers.value(); + VLOG(3) << "Estimated optimal number of workers: " + << optimal_number_of_workers.value(); metrics::RecordTFDataServiceOptimalNumberOfWorkers( optimal_number_of_workers.value()); From e2815617a600664e932abd59c13e1c2243abfb79 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 27 Jul 2023 12:46:54 -0700 Subject: [PATCH 262/410] [PJRT C] Add PJRT_Plugin_Initialize for plugin to provide initialization that will be run once when the plugin is dynamically loaded. 
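For a plugin author, the new hook is a one-time setup point that the loader invokes right after the library is loaded and `GetPjrtApi` is resolved. A hedged sketch of how an out-of-tree plugin might wire it up with the updated CreatePjrtApi helper; all MyPlugin_* names are hypothetical, and plugins with no setup to do can pass pjrt::PJRT_Plugin_Initialize_NoOp instead, as the CPU and GPU plugins in this change do.

#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h"

// Hypothetical factories assumed to be defined elsewhere in the plugin.
PJRT_Error* MyPlugin_Client_Create(PJRT_Client_Create_Args* args);
PJRT_Error* MyPlugin_TopologyDescription_Create(
    PJRT_TopologyDescription_Create_Args* args);

// One-time setup, run once when the plugin is dynamically loaded.
PJRT_Error* MyPlugin_Initialize(PJRT_Plugin_Initialize_Args* args) {
  // e.g. initialize logging or device discovery; nullptr signals success.
  return nullptr;
}

constexpr PJRT_Api pjrt_api =
    pjrt::CreatePjrtApi(MyPlugin_Client_Create,
                        MyPlugin_TopologyDescription_Create,
                        MyPlugin_Initialize);

// The loader resolves this exact symbol, as documented in pjrt_api.h.
const PJRT_Api* GetPjrtApi() { return &pjrt_api; }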
PiperOrigin-RevId: 551613160 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h | 12 ++- .../compiler/xla/pjrt/c/pjrt_c_api_cpu.cc | 3 +- .../xla/pjrt/c/pjrt_c_api_gpu_internal.cc | 3 +- .../compiler/xla/pjrt/c/pjrt_c_api_helpers.h | 13 +++ .../xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 10 +++ .../xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 9 +- tensorflow/compiler/xla/pjrt/pjrt_api.cc | 9 ++ tensorflow/compiler/xla/pjrt/pjrt_api.h | 4 +- .../compiler/xla/pjrt/pjrt_c_api_client.cc | 82 +++++++++---------- 9 files changed, 94 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h index 76b01484a80757..ec9b9db52111ca 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h @@ -53,7 +53,7 @@ extern "C" { // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 12 +#define PJRT_API_MINOR 13 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -168,6 +168,15 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_NamedValue, value_size); // ---------------------------------- Plugin ----------------------------------- +struct PJRT_Plugin_Initialize_Args { + size_t struct_size; + void* priv; +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Initialize_Args, priv); + +// One-time plugin setup. Must be called before any other functions are called. +typedef PJRT_Error* PJRT_Plugin_Initialize(PJRT_Plugin_Initialize_Args* args); + struct PJRT_Plugin_Attributes_Args { size_t struct_size; void* priv; @@ -1724,6 +1733,7 @@ typedef struct { _PJRT_API_STRUCT_FIELD(PJRT_Error_Message); _PJRT_API_STRUCT_FIELD(PJRT_Error_GetCode); + _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Initialize); _PJRT_API_STRUCT_FIELD(PJRT_Plugin_Attributes); _PJRT_API_STRUCT_FIELD(PJRT_Event_Destroy); diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu.cc index c39bfa5c950039..38746e6bc14e08 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu.cc @@ -53,6 +53,7 @@ PJRT_Error* PJRT_CpuDeviceTopology_Create( constexpr PJRT_Api pjrt_api = pjrt::CreatePjrtApi(pjrt::cpu_plugin::PJRT_Client_Create, - pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create); + pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, + pjrt::PJRT_Plugin_Initialize_NoOp); const PJRT_Api* GetPjrtApi() { return &pjrt_api; } diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 42f5f8700892b0..1c4dbbb6c751ed 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -84,7 +84,8 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( constexpr PJRT_Api pjrt_api = pjrt::CreatePjrtApi(pjrt::gpu_plugin::PJRT_Client_Create, - pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create); + pjrt::gpu_plugin::PJRT_GpuDeviceTopology_Create, + pjrt::PJRT_Plugin_Initialize_NoOp); const PJRT_Api* GetGpuPjrtApi() { return &pjrt_api; } diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h index 6b611133527b81..ac9320d5edc8ee 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h @@ -34,6 +34,19 @@ ABSL_CONST_INIT extern const 
absl::string_view kHloFormat; ABSL_CONST_INIT extern const absl::string_view kMlirFormat; ABSL_CONST_INIT extern const absl::string_view kHloWithConfigFormat; +// Return error status if not success and frees the PJRT_Error returned by +// `expr`. +#define RETURN_STATUS_IF_PJRT_ERROR(expr, c_api) \ + do { \ + PJRT_Error* error = (expr); \ + std::unique_ptr _error( \ + error, pjrt::MakeErrorDeleter(c_api)); \ + xla::Status _status = pjrt::PjrtErrorToStatus(_error.get(), c_api); \ + if (!_status.ok()) { \ + return _status; \ + } \ + } while (false) + using PJRT_ClientDeleter = std::function; // Pass in an API pointer; receive a custom deleter for smart pointers. diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 335e0720b13129..d2b050567b316b 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -212,10 +212,20 @@ PJRT_Error* PJRT_Error_GetCode(PJRT_Error_GetCode_Args* args) { // ---------------------------------- Plugin ----------------------------------- PJRT_Error* PJRT_Plugin_Attributes(PJRT_Plugin_Attributes_Args* args) { + PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes( + "PJRT_Plugin_Attributes_Args", PJRT_Plugin_Attributes_Args_STRUCT_SIZE, + args->struct_size)); args->num_attributes = 0; return nullptr; } +PJRT_Error* PJRT_Plugin_Initialize_NoOp(PJRT_Plugin_Initialize_Args* args) { + PJRT_RETURN_IF_ERROR(CheckMatchingStructSizes( + "PJRT_Plugin_Initialize_Args", PJRT_Plugin_Initialize_Args_STRUCT_SIZE, + args->struct_size)); + return nullptr; +} + // ---------------------------------- Client ----------------------------------- PJRT_Error* PJRT_Client_Destroy(PJRT_Client_Destroy_Args* args) { diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 8d0724bcb42cdf..28353d650eb590 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -316,11 +316,17 @@ xla::PjRtClient::KeyValueGetCallback ToCppKeyValueGetCallback( xla::PjRtClient::KeyValuePutCallback ToCppKeyValuePutCallback( PJRT_KeyValuePutCallback c_callback, void* user_arg); +// A method that does not nothing other than returning a nullptr. Can be used as +// the implementation of PJRT_Plugin_Initialize for plugins that do not require +// specific initialization. +PJRT_Error* PJRT_Plugin_Initialize_NoOp(PJRT_Plugin_Initialize_Args* args); + // Creates a PJRT_Api with create_fn from the input and other functions in // pjrt_c_api_wrapper_impl. 
constexpr PJRT_Api CreatePjrtApi( PJRT_Client_Create* create_fn, - PJRT_TopologyDescription_Create* topology_create_fn) { + PJRT_TopologyDescription_Create* topology_create_fn, + PJRT_Plugin_Initialize* plugin_initialize_fn) { return PJRT_Api{ /*struct_size=*/PJRT_Api_STRUCT_SIZE, /*priv=*/nullptr, @@ -335,6 +341,7 @@ constexpr PJRT_Api CreatePjrtApi( /*PJRT_Error_Message=*/pjrt::PJRT_Error_Message, /*PJRT_Error_GetCode=*/pjrt::PJRT_Error_GetCode, + /*PJRT_Plugin_Initialize=*/plugin_initialize_fn, /*PJRT_Plugin_Attributes=*/pjrt::PJRT_Plugin_Attributes, /*PJRT_Event_Destroy=*/pjrt::PJRT_Event_Destroy, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_api.cc b/tensorflow/compiler/xla/pjrt/pjrt_api.cc index ebb31db4c4f334..edad6cb0ae65b0 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_api.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_api.cc @@ -74,6 +74,15 @@ xla::Status InitPjrtPlugin(PjrtApiInitFn init_fn, } TF_RETURN_IF_ERROR(pjrt::CheckMatchingStructSizes( "PJRT_Api", PJRT_Api_STRUCT_SIZE, pjrt_api->struct_size)); + if (pjrt_api->struct_size >= 592 && + (pjrt_api->pjrt_api_version.major_version > 0 || + pjrt_api->pjrt_api_version.minor_version >= 13)) { + PJRT_Plugin_Initialize_Args args; + args.struct_size = PJRT_Plugin_Initialize_Args_STRUCT_SIZE; + args.priv = nullptr; + RETURN_STATUS_IF_PJRT_ERROR(pjrt_api->PJRT_Plugin_Initialize(&args), + pjrt_api); + } return SetPjrtApi(device_type, pjrt_api); } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_api.h b/tensorflow/compiler/xla/pjrt/pjrt_api.h index ec726824f45213..ee115c9c03c703 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_api.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_api.h @@ -31,12 +31,12 @@ xla::Status SetPjrtApi(absl::string_view device_type, const PJRT_Api* api); // Loads a PJRT plugin. The library provided by library_path must export a // symbol called `GetPjrtApi` with function signature `const PJRT_Api* // GetPjrtApi()`. This method dlopen the plugin library, dlsym `GetPjrtApi`, -// calls `GetPjrtApi`, and `SetPjrtApi`. +// calls `GetPjrtApi`, `SetPjrtApi`, and `PJRT_Plugin_Initialize`. xla::Status LoadPjrtPlugin(absl::string_view device_type, absl::string_view library_path); // Initializes PJRT with a PjrtApiInitFn which is dynamically loaded. This -// method calls init_fn, and `SetPjrtApi`. +// method calls init_fn, `SetPjrtApi` and `PJRT_Plugin_Initialize`. typedef const PJRT_Api* (*PjrtApiInitFn)(); xla::Status InitPjrtPlugin(PjrtApiInitFn init_fn, absl::string_view device_type); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc index fd6bfffb043b00..ebad86d42f3811 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc @@ -50,19 +50,6 @@ namespace xla { // Helper macros -// Return error status if not success and frees the PJRT_Error returned by -// `expr`. -#define RETURN_STATUS_IF_ERROR(expr, c_api) \ - do { \ - PJRT_Error* error = (expr); \ - std::unique_ptr _error( \ - error, pjrt::MakeErrorDeleter(c_api)); \ - xla::Status _status = pjrt::PjrtErrorToStatus(_error.get(), c_api); \ - if (!_status.ok()) { \ - return _status; \ - } \ - } while (false) - // Return error future if not success and frees the PJRT_Error returned by // `expr`. 
#define RETURN_FUTURE_IF_ERROR(expr, c_api) \ @@ -197,8 +184,8 @@ StatusOr PjRtCApiClient::GetDefaultDeviceAssignment( std::vector assignment_buffer(num_replicas * num_partitions); args.default_assignment_size = assignment_buffer.size(); args.default_assignment = assignment_buffer.data(); - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_DefaultDeviceAssignment(&args), - c_api_); + RETURN_STATUS_IF_PJRT_ERROR( + c_api_->PJRT_Client_DefaultDeviceAssignment(&args), c_api_); absl::Span param{args.default_assignment, args.default_assignment_size}; return CalculateDefaultAssignment(args.num_replicas, args.num_partitions, @@ -216,7 +203,7 @@ StatusOr PjRtCApiClient::LookupDevice(int device_id) const { args.priv = nullptr; args.client = c_client_.get(); args.id = device_id; - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_); + RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Client_LookupDevice(&args), c_api_); return GetCppDevice(args.device); } @@ -227,8 +214,8 @@ StatusOr PjRtCApiClient::LookupAddressableDevice( args.priv = nullptr; args.client = c_client_.get(); args.local_hardware_id = local_hardware_id; - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_LookupAddressableDevice(&args), - c_api_); + RETURN_STATUS_IF_PJRT_ERROR( + c_api_->PJRT_Client_LookupAddressableDevice(&args), c_api_); return GetCppDevice(args.addressable_device); } @@ -257,7 +244,7 @@ static StatusOr> InitializeArgsAndCompile( program.format_size = format.size(); args.program = &program; - RETURN_STATUS_IF_ERROR(c_api->PJRT_Client_Compile(&args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Client_Compile(&args), c_api); std::unique_ptr ret = std::make_unique(api_client, args.executable); return ret; @@ -301,8 +288,8 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized, const PJRT_Api* api = pjrt_c_api(); - RETURN_STATUS_IF_ERROR(api->PJRT_Executable_DeserializeAndLoad(&des_args), - api); + RETURN_STATUS_IF_PJRT_ERROR( + api->PJRT_Executable_DeserializeAndLoad(&des_args), api); PJRT_LoadedExecutable* c_exec = des_args.loaded_executable; CHECK(c_exec != nullptr); return std::unique_ptr( @@ -330,7 +317,7 @@ StatusOr PjRtCApiClient::UnsafeBufferPointer( args.buffer = tensorflow::down_cast(buffer)->c_buffer(); - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Buffer_UnsafePointer(&args), c_api_); + RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Buffer_UnsafePointer(&args), c_api_); return args.buffer_pointer; } @@ -381,8 +368,8 @@ StatusOr> PjRtCApiClient::BufferFromHostBuffer( ::pjrt::ConvertToPjRtHostBufferSemantics(host_buffer_semantics); args.device = tensorflow::down_cast(device)->c_device(); - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Client_BufferFromHostBuffer(&args), - c_api_); + RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Client_BufferFromHostBuffer(&args), + c_api_); auto buffer = std::unique_ptr( std::make_unique(this, args.buffer)); @@ -410,7 +397,8 @@ StatusOr> PjRtCApiClient::BufferFromHostBuffer( delete on_done_with_host_buffer; }; - RETURN_STATUS_IF_ERROR(c_api_->PJRT_Event_OnReady(&event_args), c_api_); + RETURN_STATUS_IF_PJRT_ERROR(c_api_->PJRT_Event_OnReady(&event_args), + c_api_); } return buffer; @@ -575,7 +563,7 @@ StatusOr PjRtCApiDevice::GetAllocatorStats() const { args.priv = nullptr; args.device = device_; const PJRT_Api* api = client_->pjrt_c_api(); - RETURN_STATUS_IF_ERROR(api->PJRT_Device_MemoryStats(&args), api); + RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Device_MemoryStats(&args), api); tsl::AllocatorStats result; result.bytes_in_use = args.bytes_in_use; @@ -696,7 +684,8 @@ 
PjRtCApiExecutable::GetCostAnalysis() const { // Make PJRT C API call const PJRT_Api* c_api = pjrt_c_api(); - RETURN_STATUS_IF_ERROR(c_api->PJRT_Executable_GetCostAnalysis(&args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_GetCostAnalysis(&args), + c_api); // Copy returned properties to output map return pjrt::ConvertFromPjRtNamedValueList(args.properties, @@ -717,14 +706,16 @@ PjRtCApiExecutable::GetHloModules() const { program.code = nullptr; args.program = &program; - RETURN_STATUS_IF_ERROR(c_api->PJRT_Executable_OptimizedProgram(&args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_OptimizedProgram(&args), + c_api); constexpr size_t TWO_GIBIBYTES = 2ull * 1024 * 1024 * 1024; const size_t code_size = args.program->code_size; CHECK(code_size < TWO_GIBIBYTES); std::string code(code_size, ' '); args.program->code = code.data(); - RETURN_STATUS_IF_ERROR(c_api->PJRT_Executable_OptimizedProgram(&args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_OptimizedProgram(&args), + c_api); absl::string_view program_format(program.format, program.format_size); if (program_format != ::pjrt::kHloWithConfigFormat && @@ -783,7 +774,8 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { ser_args.executable = executable; ser_args.serialized_executable = nullptr; - RETURN_STATUS_IF_ERROR(c_api->PJRT_Executable_Serialize(&ser_args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Executable_Serialize(&ser_args), + c_api); PJRT_SerializedExecutable* c_serialized_exec = ser_args.serialized_executable; std::unique_ptr @@ -797,8 +789,8 @@ StatusOr PjRtCApiExecutable::SerializeExecutable() const { data_args.data = nullptr; data_args.data_size = 0; - RETURN_STATUS_IF_ERROR(c_api->PJRT_SerializedExecutable_Data(&data_args), - c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_SerializedExecutable_Data(&data_args), + c_api); return std::string(data_args.data, data_args.data_size); } @@ -1117,7 +1109,7 @@ PjRtCApiLoadedExecutable::GetCommonExecuteArgs( numoutputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE; numoutputs_args.priv = nullptr; numoutputs_args.executable = c_executable(); - RETURN_STATUS_IF_ERROR( + RETURN_STATUS_IF_PJRT_ERROR( pjrt_c_api()->PJRT_Executable_NumOutputs(&numoutputs_args), pjrt_c_api()); size_t outer_size = args.num_devices; size_t inner_size = numoutputs_args.num_outputs; @@ -1184,8 +1176,8 @@ PjRtCApiLoadedExecutable::Execute( args.execute_device = nullptr; - RETURN_STATUS_IF_ERROR(pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), - pjrt_c_api()); + RETURN_STATUS_IF_PJRT_ERROR( + pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); if (device_complete_events.has_value()) { std::vector> device_complete_futures; @@ -1250,8 +1242,8 @@ PjRtCApiLoadedExecutable::ExecuteWithSingleDevice( args.execute_device = tensorflow::down_cast(device)->c_device(); - RETURN_STATUS_IF_ERROR(pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), - pjrt_c_api()); + RETURN_STATUS_IF_PJRT_ERROR( + pjrt_c_api()->PJRT_LoadedExecutable_Execute(&args), pjrt_c_api()); if (fill_future) { *returned_future = pjrt::ConvertCEventToCppFuture( @@ -1367,8 +1359,8 @@ StatusOr> PjRtCApiBuffer::logical_dimensions() { args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; args.priv = nullptr; args.buffer = buffer_.get(); - RETURN_STATUS_IF_ERROR(pjrt_c_api()->PJRT_Buffer_UnpaddedDimensions(&args), - pjrt_c_api()); + RETURN_STATUS_IF_PJRT_ERROR( + pjrt_c_api()->PJRT_Buffer_UnpaddedDimensions(&args), pjrt_c_api()); return 
std::vector(args.unpadded_dims, args.unpadded_dims + args.num_dims); } @@ -1552,7 +1544,7 @@ StatusOr PjRtCApiBuffer::GetOnDeviceSizeInBytes() const { args.struct_size = PJRT_Buffer_OnDeviceSizeInBytes_Args_STRUCT_SIZE; args.priv = nullptr; args.buffer = buffer_.get(); - RETURN_STATUS_IF_ERROR( + RETURN_STATUS_IF_PJRT_ERROR( client_->pjrt_c_api()->PJRT_Buffer_OnDeviceSizeInBytes(&args), client_->pjrt_c_api()); @@ -1598,7 +1590,7 @@ StatusOr> PjRtCApiBuffer::CopyToDevice( args.dst_device = tensorflow::down_cast(dst_device)->c_device(); const PJRT_Api* api = pjrt_c_api(); - RETURN_STATUS_IF_ERROR(api->PJRT_Buffer_CopyToDevice(&args), api); + RETURN_STATUS_IF_PJRT_ERROR(api->PJRT_Buffer_CopyToDevice(&args), api); return std::unique_ptr( std::make_unique(client_, args.dst_buffer)); } else { @@ -1760,7 +1752,7 @@ static StatusOr> InitializeArgsAndCompileAot( program.format_size = format.size(); args.program = &program; - RETURN_STATUS_IF_ERROR(c_api->PJRT_Compile(&args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Compile(&args), c_api); std::unique_ptr ret = std::make_unique(c_api, args.executable); return ret; @@ -1829,7 +1821,7 @@ StatusOr> GetCApiClient( device_type); } - RETURN_STATUS_IF_ERROR(c_api->PJRT_Client_Create(&init_args), c_api); + RETURN_STATUS_IF_PJRT_ERROR(c_api->PJRT_Client_Create(&init_args), c_api); PJRT_Client* c_client = init_args.client; return std::unique_ptr(std::make_unique( @@ -1853,8 +1845,8 @@ StatusOr> GetCApiTopology( init_args.num_options = c_options.size(); init_args.topology_name = topology_name.data(); init_args.topology_name_size = topology_name.size(); - RETURN_STATUS_IF_ERROR(c_api->PJRT_TopologyDescription_Create(&init_args), - c_api); + RETURN_STATUS_IF_PJRT_ERROR( + c_api->PJRT_TopologyDescription_Create(&init_args), c_api); PJRT_TopologyDescription* c_topology = init_args.topology; return std::unique_ptr( std::make_unique(c_api, c_topology)); From 00cbb02aba06d854178faec70f4ac2fdf8a4b430 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Thu, 27 Jul 2023 12:56:24 -0700 Subject: [PATCH 263/410] [TF:PJRT] Set device ordinal in compile option to be physical device ordinal. When there are virtual devices (multiple TF devices on the same physical device), the id in device name is TF device id. These TF devices share the same stream executor. The device ordinal in compile option should be the id for the stream executor (physical device id). 
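The distinction matters because with virtual devices the id parsed from the TF device name (for example "/device:GPU:2") is not the ordinal of the physical device that owns the shared stream executor. A small sketch of the mapping this change consults, modeled on the DeviceIdManager calls used in the tests below; header paths are approximate and LookupCompileOrdinal is an illustrative name.

#include "tensorflow/tsl/framework/device_id.h"
#include "tensorflow/tsl/framework/device_id_manager.h"
#include "tensorflow/tsl/framework/device_id_utils.h"
#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/statusor.h"
#include "tensorflow/tsl/util/device_name_utils.h"

// Three virtual TF devices backed by a single physical GPU: TF ids 0..2 all
// map to platform id 0, which is the ordinal the shared stream executor (and
// therefore the XLA compile options) must use.
tsl::StatusOr<int> LookupCompileOrdinal() {
  const tsl::DeviceType device_type("GPU");
  for (int tf_id = 0; tf_id < 3; ++tf_id) {
    TF_RETURN_IF_ERROR(tsl::DeviceIdManager::InsertTfPlatformDeviceIdPair(
        device_type, tsl::TfDeviceId(tf_id), tsl::PlatformDeviceId(0)));
  }

  tsl::DeviceNameUtils::ParsedName device_name;
  device_name.id = 2;  // parsed from "/device:GPU:2", a virtual device
  // Returns 0, the physical ordinal, rather than the TF device id 2.
  return tsl::GetPlatformDeviceIdFromDeviceParsedName(device_name, device_type);
}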
PiperOrigin-RevId: 551615742 --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/xla_compiler_options_util.cc | 12 ++- tensorflow/python/compiler/xla/BUILD | 28 +++++++ .../xla/pjrt_compile_virtual_device_test.py | 75 +++++++++++++++++++ tensorflow/tsl/framework/BUILD | 1 + tensorflow/tsl/framework/device_id_utils.cc | 24 ++++-- tensorflow/tsl/framework/device_id_utils.h | 5 ++ .../tsl/framework/device_id_utils_test.cc | 27 +++++++ 8 files changed, 167 insertions(+), 6 deletions(-) create mode 100644 tensorflow/python/compiler/xla/pjrt_compile_virtual_device_test.py diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0e3171b028fd4b..32ac16094ef8ae 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -358,6 +358,7 @@ cc_library( "//tensorflow/core/tfrt/common:global_state", "//tensorflow/core/tfrt/common:pjrt_util", "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/tsl/framework:device_id_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/jit/xla_compiler_options_util.cc b/tensorflow/compiler/jit/xla_compiler_options_util.cc index 1ba962380d83fc..b170fb3cd0b4e9 100644 --- a/tensorflow/compiler/jit/xla_compiler_options_util.cc +++ b/tensorflow/compiler/jit/xla_compiler_options_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compiler_options_util.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/tsl/framework/device_id_utils.h" namespace tensorflow { namespace { @@ -88,7 +89,16 @@ XlaCompiler::Options GenerateCompilerOptionsForPjRt( const DeviceBase* device_base, const XlaPlatformInfo& platform_info, const PjRtDeviceCompiler* pjrt_device_compiler) { XlaCompiler::Options options; - options.device_ordinal = device_base->parsed_name().id; + StatusOr platform_device_id = + tsl::GetPlatformDeviceIdFromDeviceParsedName( + device_base->parsed_name(), + DeviceType(tensorflow::down_cast(device_base) + ->device_type())); + if (platform_device_id.ok()) { + options.device_ordinal = *platform_device_id; + } else { + options.device_ordinal = device_base->parsed_name().id; + } options.flib_def = function_library.GetFunctionLibraryDefinition(); options.graph_def_version = function_library.graph_def_version(); if (const auto* metadata = platform_info.xla_device_metadata(); diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index 5d079abafd1270..535cf6f1842460 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -152,3 +152,31 @@ tf_py_strict_test( "//tensorflow/python/ops:variables", ], ) + +tf_py_strict_test( + name = "pjrt_compile_virtual_device_test", + srcs = ["pjrt_compile_virtual_device_test.py"], + env = { + "TF_XLA_FLAGS": "--tf_xla_use_device_api --tf_xla_enable_xla_devices --tf_xla_enable_device_api_for_gpu", + }, + python_version = "PY3", + tags = [ + "config-cuda-only", + "gpu", + "no_oss", + "requires-gpu-nvidia", + "xla", + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + ], +) diff --git 
a/tensorflow/python/compiler/xla/pjrt_compile_virtual_device_test.py b/tensorflow/python/compiler/xla/pjrt_compile_virtual_device_test.py new file mode 100644 index 00000000000000..dde00db3b183f6 --- /dev/null +++ b/tensorflow/python/compiler/xla/pjrt_compile_virtual_device_test.py @@ -0,0 +1,75 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests virtual device compilation + execution using the Device API (aka PjRt). + +This feature is still under active development and is protected behind the +`--tf_xla_use_device_api` flag in the `TF_XLA_FLAGS` environment variable. +""" +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables + + +class PjrtCompileVirtualDeviceTest(test.TestCase): + + def setUp(self): + super().setUp() + gpus = config.list_physical_devices("GPU") + config.set_logical_device_configuration( + gpus[0], + [ + context.LogicalDeviceConfiguration(memory_limit=1024), + context.LogicalDeviceConfiguration(memory_limit=1024), + context.LogicalDeviceConfiguration(memory_limit=1024), + ], + ) + + def test_xla_launch_and_tf_kernel_on_gpu_device(self): + @def_function.function(jit_compile=True) + def foo(x, y): + return x + y + 1 + + @def_function.function(jit_compile=True) + def bar(x, y): + x.assign(y) + y.assign_add([1.0, 1.0]) + + with ops.device("/device:GPU:1"): + a = constant_op.constant([1.0, 2.0]) + x = variables.Variable([0.0, 1.0]) + result_tensor = foo(x, a) + + self.assertAllClose(result_tensor.numpy(), [2.0, 4.0], atol=1e-05) + + # The following use case is tested: + # Variable updates following an XLA computation that reads the updated + # variables. 
+ with ops.device("/device:GPU:1"): + var_a = variables.Variable([0.0, 1.0]) + var_b = variables.Variable([1.0, 2.0]) + bar(var_a, var_b) + result = foo(var_a, var_b) + + self.assertAllClose([1.0, 2.0], var_a.value(), atol=1e-05) + self.assertAllClose([2.0, 3.0], var_b.value(), atol=1e-05) + self.assertAllClose(result, [4.0, 6.0], atol=1e-05) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/tsl/framework/BUILD b/tensorflow/tsl/framework/BUILD index 53913974b91058..0999e9d87b5736 100644 --- a/tensorflow/tsl/framework/BUILD +++ b/tensorflow/tsl/framework/BUILD @@ -430,6 +430,7 @@ tsl_cc_test( "//tensorflow/tsl/lib/core:status_test_util", "//tensorflow/tsl/platform:status_matchers", "//tensorflow/tsl/platform:test_main", + "//tensorflow/tsl/protobuf:error_codes_proto_impl_cc", "//tensorflow/tsl/util:device_name_utils", ], ) diff --git a/tensorflow/tsl/framework/device_id_utils.cc b/tensorflow/tsl/framework/device_id_utils.cc index 82ee7f59a78745..8178842295e107 100644 --- a/tensorflow/tsl/framework/device_id_utils.cc +++ b/tensorflow/tsl/framework/device_id_utils.cc @@ -29,6 +29,12 @@ limitations under the License. #include "tensorflow/tsl/platform/str_util.h" namespace tsl { +namespace { +int GetTfDeviceIdFromDeviceParsedName( + const DeviceNameUtils::ParsedName& device_name) { + return device_name.id; +} +} // namespace void CheckValidTfDeviceId(const DeviceType& type, const int visible_device_count, @@ -117,20 +123,28 @@ StatusOr GetNumberTfDevicesAndConfigurePlatformDeviceId( return num_tf_devices; } -StatusOr GetDeviceIdFromDeviceParsedName( +StatusOr GetPlatformDeviceIdFromDeviceParsedName( const DeviceNameUtils::ParsedName& device_name, const DeviceType& device_type) { - const TfDeviceId tf_device_id(device_name.id); + const TfDeviceId tf_device_id(GetTfDeviceIdFromDeviceParsedName(device_name)); PlatformDeviceId platform_device_id; Status platform_id_status = DeviceIdManager::TfToPlatformDeviceId( device_type, tf_device_id, &platform_device_id); if (platform_id_status.ok()) { return platform_device_id.value(); } - if (absl::IsNotFound(platform_id_status)) { - return tf_device_id.value(); - } return platform_id_status; } +StatusOr GetDeviceIdFromDeviceParsedName( + const DeviceNameUtils::ParsedName& device_name, + const DeviceType& device_type) { + auto platform_id = + GetPlatformDeviceIdFromDeviceParsedName(device_name, device_type); + if (platform_id.ok()) { + return *platform_id; + } + return GetTfDeviceIdFromDeviceParsedName(device_name); +} + } // namespace tsl diff --git a/tensorflow/tsl/framework/device_id_utils.h b/tensorflow/tsl/framework/device_id_utils.h index 9f361e8b4358a4..9b270437cb2808 100644 --- a/tensorflow/tsl/framework/device_id_utils.h +++ b/tensorflow/tsl/framework/device_id_utils.h @@ -56,6 +56,11 @@ StatusOr GetNumberTfDevicesAndConfigurePlatformDeviceId( absl::string_view device_type, absl::string_view visible_device_list, int visible_device_count); +StatusOr GetPlatformDeviceIdFromDeviceParsedName( + const DeviceNameUtils::ParsedName& device_name, + const DeviceType& device_type); + +// TODO(b/293324740): support virtual devices. // Returns the corresponding PlatformDeviceId if it is found. Otherwise returns // the id in device_name. 
StatusOr GetDeviceIdFromDeviceParsedName( diff --git a/tensorflow/tsl/framework/device_id_utils_test.cc b/tensorflow/tsl/framework/device_id_utils_test.cc index 7bd023778f7632..6f5cb3dd963df7 100644 --- a/tensorflow/tsl/framework/device_id_utils_test.cc +++ b/tensorflow/tsl/framework/device_id_utils_test.cc @@ -147,6 +147,33 @@ TEST(DeviceIdUtilsTest, GetNumberTfDevicesWithSessionOptionDeviceCount) { DeviceIdManager::TestOnlyReset(); } +TEST(DeviceIdUtilsTest, GetPlatformDeviceId) { + TfDeviceId tf_device_id(0); + PlatformDeviceId platform_device_id(1); + TF_EXPECT_OK(DeviceIdManager::InsertTfPlatformDeviceIdPair( + DeviceType(kTestDeviceType), tf_device_id, platform_device_id)); + DeviceNameUtils::ParsedName device_name; + device_name.id = 0; + + TF_ASSERT_OK_AND_ASSIGN(int device_id, + GetPlatformDeviceIdFromDeviceParsedName( + device_name, DeviceType(kTestDeviceType))); + + EXPECT_EQ(device_id, 1); + DeviceIdManager::TestOnlyReset(); +} + +TEST(DeviceIdUtilsTest, GetPlatformDeviceIdNotFound) { + DeviceNameUtils::ParsedName device_name; + device_name.id = 0; + + EXPECT_THAT( + GetPlatformDeviceIdFromDeviceParsedName(device_name, + DeviceType(kTestDeviceType)), + StatusIs(tensorflow::error::NOT_FOUND, + HasSubstr("TensorFlow device CPU:0 was not registered"))); +} + TEST(DeviceIdUtilsTest, GetDeviceIdWithPlatformDeviceId) { TfDeviceId tf_device_id(0); PlatformDeviceId platform_device_id(1); From dd21cba610ad9d7f72aea4a0aaf99fd11ae3c881 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 12:59:32 -0700 Subject: [PATCH 264/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/81aca58669e2db2c53168db65ecdf40add3dcd2d. PiperOrigin-RevId: 551616539 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 1639568a63e79e..f5a5e9f61cab3c 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "3bf6c17968a52aea580c5398bbcfc0cf0e069dc5" - TFRT_SHA256 = "25c973d5ea4cdf6a0762fb3dead5162339ba7fa00f67cc2681e55f6da6c796ab" + TFRT_COMMIT = "81aca58669e2db2c53168db65ecdf40add3dcd2d" + TFRT_SHA256 = "f9738cfe7e65b03dabdc5ba61da39a93bdf3719de48193f303e3bf475da813a4" tf_http_archive( name = "tf_runtime", From 2c8505a0a58cdba033d471324b9b9061663d349c Mon Sep 17 00:00:00 2001 From: Marc Fisher Date: Thu, 27 Jul 2023 13:06:22 -0700 Subject: [PATCH 265/410] Fix comments in .bazelrc Change // to # --- .bazelrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index 21d45f9991ff44..a1286bf2b58ec3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -415,7 +415,7 @@ build:rbe --google_default_credentials build:rbe --bes_backend=buildeventservice.googleapis.com build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" build:rbe --bes_timeout=600s -// TODO(b/290857564): Remove from mainline after 2.14 branch cut +# TODO(b/290857564): Remove from mainline after 2.14 branch cut build:rbe --flaky_test_attempts=3 build:rbe --define=EXECUTOR=remote build:rbe --jobs=800 @@ -557,7 +557,7 @@ try-import %workspace%/.bazelrc.user # Here are bazelrc configs for release builds build:release_base --config=v2 test:release_base --test_size_filters=small,medium -// TODO(b/290857564): Remove from mainline after 2.14 branch cut +# TODO(b/290857564): Remove from mainline after 2.14 branch cut test:release_base --flaky_test_attempts=3 build:release_cpu_linux --config=release_base From 141bb0834c6eecfe54233092b0b3acaa53801eee Mon Sep 17 00:00:00 2001 From: Anlun Xu Date: Thu, 27 Jul 2023 13:00:11 -0700 Subject: [PATCH 266/410] [xla:gpu] Improve the mechanism for terminating the concurrent region If the current region is terminated because of reaching an operation that introduces an dependency, we should add the operation to a new empty region. PiperOrigin-RevId: 551616712 --- .../gpu/transforms/add_concurrent_regions.cc | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc index 6c2b840181795e..17b7c57068e6c1 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_concurrent_regions.cc @@ -142,6 +142,16 @@ llvm::SmallVector GetRegionInfos( region.clear(); }; + auto append_node_to_region = [&](const DataflowAnalysis::Node& node) { + if (region.empty()) { + if (!IsNoOp(node.operation)) { + region.push_back(node); + } + } else { + region.push_back(node); + } + }; + for (const DataflowAnalysis::Node& node : dataflow_graph) { if (isa(node.operation)) { break; @@ -157,15 +167,13 @@ llvm::SmallVector GetRegionInfos( } } - if (has_dependency || IsKernelMemoryBound(node.operation)) { + if (IsKernelMemoryBound(node.operation)) { + store_region_and_start_new_region(); + } else if (has_dependency) { store_region_and_start_new_region(); + append_node_to_region(node); } else { - // No dependency with the current region. - if (region.empty() && IsNoOp(node.operation)) { - continue; - } else { - region.push_back(node); - } + append_node_to_region(node); } } From 17f11ef56e83a709e7b376fa4d88f08068e504fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 13:20:14 -0700 Subject: [PATCH 267/410] Move `TopologicalIterator` to a separate file for reusing. 
PiperOrigin-RevId: 551622104 --- tensorflow/dtensor/mlir/BUILD | 20 +++++ tensorflow/dtensor/mlir/sparse_expansion.cc | 1 + .../dtensor/mlir/spmd_expander_common.cc | 56 ------------- .../dtensor/mlir/spmd_expander_common.h | 34 -------- tensorflow/dtensor/mlir/spmd_expansion.cc | 1 + .../dtensor/mlir/topological_iterator.cc | 81 +++++++++++++++++++ .../dtensor/mlir/topological_iterator.h | 63 +++++++++++++++ 7 files changed, 166 insertions(+), 90 deletions(-) create mode 100644 tensorflow/dtensor/mlir/topological_iterator.cc create mode 100644 tensorflow/dtensor/mlir/topological_iterator.h diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 2e7f1a66247b47..492acc148b08f6 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -228,6 +228,7 @@ cc_library( ":spmd_expander", ":spmd_expander_common", ":tf_dtensor_dialect", + ":topological_iterator", ":value_utils", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", @@ -505,6 +506,25 @@ cc_library( alwayslink = True, ) +cc_library( + name = "topological_iterator", + srcs = ["topological_iterator.cc"], + hdrs = ["topological_iterator.h"], + deps = [ + ":device_utils", + ":layout_parsing", + ":op_utils", + ":shape_utils", + ":tf_dtensor_dialect", + ":value_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + tf_cc_test( name = "dtensor_location_test", srcs = ["dtensor_location_test.cc"], diff --git a/tensorflow/dtensor/mlir/sparse_expansion.cc b/tensorflow/dtensor/mlir/sparse_expansion.cc index a88ccf5551c3c8..16d63399fe6931 100644 --- a/tensorflow/dtensor/mlir/sparse_expansion.cc +++ b/tensorflow/dtensor/mlir/sparse_expansion.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/sparse_expander.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" +#include "tensorflow/dtensor/mlir/topological_iterator.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.cc b/tensorflow/dtensor/mlir/spmd_expander_common.cc index 1b1fd774ef32ac..7f411c4da74c73 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.cc +++ b/tensorflow/dtensor/mlir/spmd_expander_common.cc @@ -24,10 +24,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -38,13 +35,8 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/DebugStringHelper.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/mlir_hlo/utils/convert_op_folder.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/cc/constants.h" #include "tensorflow/dtensor/cc/tensor_layout.h" @@ -759,53 +751,5 @@ StatusOr ExtractConstScalarStringFromValue(mlir::Value value) { return std::string(*attr.getRawStringData().begin()); } -TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func) - : ops_to_visit_{&main_func.front().front()} { - funcs_visited_.insert(main_func.getName()); - funcs_visited_in_call_stack_.insert(main_func.getName()); -} - -mlir::Operation* TopologicalIterator::next() { - if (!hasNext()) return nullptr; - - auto* op = ops_to_visit_.pop_back_val(); - auto* next_op = op->getNextNode(); - if (next_op) ops_to_visit_.push_back(next_op); - - // If this is a function call op, push the first op of the function body so - // that the function body is converted before the call site. - std::optional func = MaybeFindFunction(op); - if (func.has_value()) { - mlir::StringRef func_name = func->getName(); - - if (funcs_visited_.contains(func_name)) return next(); - - ops_to_visit_.push_back(&(func->front().front())); - funcs_visited_.insert(func_name); - } - - // If we have reached the end of a function body, remove the function from - // our active set. - if (!next_op && !funcs_visited_in_call_stack_.empty()) - if (auto func = op->getParentOfType()) - funcs_visited_in_call_stack_.erase(func.getName()); - - if (auto cluster_op = mlir::dyn_cast(op)) - ops_to_visit_.push_back(&cluster_op.GetBody().front()); - - if (auto while_op = mlir::dyn_cast(op)) { - ops_to_visit_.push_back(&while_op.getCond().front().front()); - ops_to_visit_.push_back(&while_op.getBody().front().front()); - } - - if (auto if_op = mlir::dyn_cast(op)) { - ops_to_visit_.push_back(&if_op.getThenBranch().front().front()); - ops_to_visit_.push_back(&if_op.getElseBranch().front().front()); - } - return op; -} - -bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); } - } // namespace dtensor } // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.h b/tensorflow/dtensor/mlir/spmd_expander_common.h index ec1d52ba6203f2..3d7bb181150f71 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.h +++ b/tensorflow/dtensor/mlir/spmd_expander_common.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallPtrSet.h" @@ -32,7 +31,6 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" @@ -179,38 +177,6 @@ Status ExtractConstStringVectorFromValue( StatusOr ExtractConstScalarStringFromValue(mlir::Value value); -// A general Iterator that visits a FuncOp's body in topological order. Note -// that this does not visit the given FuncOp itself. Function ops are visited -// exactly once if functions are used in multiple call sites. -// -// An example usage of this Iterator is for SPMD Expansion or Sparse -// Expansion, where we expand ops in topological order starting from the -// `main` FuncOp, only visiting function ops once so that we don't expand -// multiple times. -class TopologicalIterator { - public: - explicit TopologicalIterator(mlir::func::FuncOp main_func); - - // Returns whether there is any further ops to visit. - bool hasNext(); - - // Returns the next op to visit in the topological ordering. Returns - // a nullptr if there is no next op to visit. - mlir::Operation* next(); - - private: - // Stack to keep track of ops to visit. - llvm::SmallVector ops_to_visit_; - - // Keep track of functions we are walking, this is needed to avoid recursive - // function calls. - llvm::SmallDenseSet funcs_visited_in_call_stack_; - - // Keep track of all visit functions. This is to guarantee that - // functions are visited exactly once if functions are used in multiple - // callsites. - llvm::SmallDenseSet funcs_visited_; -}; } // namespace dtensor } // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/spmd_expansion.cc b/tensorflow/dtensor/mlir/spmd_expansion.cc index c2dc62d17e2032..f512edd981ebfd 100644 --- a/tensorflow/dtensor/mlir/spmd_expansion.cc +++ b/tensorflow/dtensor/mlir/spmd_expansion.cc @@ -53,6 +53,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/op_utils.h" #include "tensorflow/dtensor/mlir/spmd_expander.h" #include "tensorflow/dtensor/mlir/spmd_expander_common.h" +#include "tensorflow/dtensor/mlir/topological_iterator.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/dtensor/mlir/topological_iterator.cc b/tensorflow/dtensor/mlir/topological_iterator.cc new file mode 100644 index 00000000000000..e78cad677745c8 --- /dev/null +++ b/tensorflow/dtensor/mlir/topological_iterator.cc @@ -0,0 +1,81 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/dtensor/mlir/topological_iterator.h" + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/dtensor/mlir/op_utils.h" + +namespace tensorflow { +namespace dtensor { + +TopologicalIterator::TopologicalIterator(mlir::func::FuncOp main_func) + : ops_to_visit_{&main_func.front().front()} { + funcs_visited_.insert(main_func.getName()); + funcs_visited_in_call_stack_.insert(main_func.getName()); +} + +mlir::Operation* TopologicalIterator::next() { + if (!hasNext()) return nullptr; + + auto* op = ops_to_visit_.pop_back_val(); + auto* next_op = op->getNextNode(); + if (next_op) ops_to_visit_.push_back(next_op); + + // If this is a function call op, push the first op of the function body so + // that the function body is converted before the call site. + std::optional func = MaybeFindFunction(op); + if (func.has_value()) { + mlir::StringRef func_name = func->getName(); + + if (funcs_visited_.contains(func_name)) return next(); + + ops_to_visit_.push_back(&(func->front().front())); + funcs_visited_.insert(func_name); + } + + // If we have reached the end of a function body, remove the function from + // our active set. + if (!next_op && !funcs_visited_in_call_stack_.empty()) + if (auto func = op->getParentOfType()) + funcs_visited_in_call_stack_.erase(func.getName()); + + if (auto cluster_op = mlir::dyn_cast(op)) + ops_to_visit_.push_back(&cluster_op.GetBody().front()); + + if (auto while_op = mlir::dyn_cast(op)) { + ops_to_visit_.push_back(&while_op.getCond().front().front()); + ops_to_visit_.push_back(&while_op.getBody().front().front()); + } + + if (auto if_op = mlir::dyn_cast(op)) { + ops_to_visit_.push_back(&if_op.getThenBranch().front().front()); + ops_to_visit_.push_back(&if_op.getElseBranch().front().front()); + } + return op; +} + +bool TopologicalIterator::hasNext() { return !ops_to_visit_.empty(); } + +} // namespace dtensor +} // namespace tensorflow diff --git a/tensorflow/dtensor/mlir/topological_iterator.h b/tensorflow/dtensor/mlir/topological_iterator.h new file mode 100644 index 00000000000000..dce7d6e115b1f8 --- /dev/null +++ b/tensorflow/dtensor/mlir/topological_iterator.h @@ -0,0 +1,63 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_DTENSOR_MLIR_TOPOLOGICAL_ITERATOR_H_ +#define TENSORFLOW_DTENSOR_MLIR_TOPOLOGICAL_ITERATOR_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project + +namespace tensorflow { +namespace dtensor { + +// A general Iterator that visits a FuncOp's body in topological order. Note +// that this does not visit the given FuncOp itself. Function ops are visited +// exactly once if functions are used in multiple call sites. +// +// An example usage of this Iterator is for SPMD Expansion or Sparse +// Expansion, where we expand ops in topological order starting from the +// `main` FuncOp, only visiting function ops once so that we don't expand +// multiple times. +class TopologicalIterator { + public: + explicit TopologicalIterator(mlir::func::FuncOp main_func); + + // Returns whether there is any further ops to visit. + bool hasNext(); + + // Returns the next op to visit in the topological ordering. Returns + // a nullptr if there is no next op to visit. + mlir::Operation* next(); + + private: + // Stack to keep track of ops to visit. + llvm::SmallVector ops_to_visit_; + + // Keep track of functions we are walking, this is needed to avoid recursive + // function calls. + llvm::SmallDenseSet funcs_visited_in_call_stack_; + + // Keep track of all visit functions. This is to guarantee that + // functions are visited exactly once if functions are used in multiple + // callsites. + llvm::SmallDenseSet funcs_visited_; +}; + +} // namespace dtensor +} // namespace tensorflow + +#endif // TENSORFLOW_DTENSOR_MLIR_TOPOLOGICAL_ITERATOR_H_ From e9ab84152140061f10aa13ad5676a46f218ab87a Mon Sep 17 00:00:00 2001 From: Fiona Lang Date: Thu, 27 Jul 2023 13:42:06 -0700 Subject: [PATCH 268/410] Move platform imports in python/__init__.py to modules_with_exports.py. 
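(Aside, before this commit's Python changes: a usage sketch for the TopologicalIterator moved into topological_iterator.h/.cc by the previous patch.) This is a minimal, hypothetical driver, not code from the patch; it assumes a module that already contains a `main` func.func and only exercises the hasNext()/next() contract documented in the header — callee bodies are visited before their call sites, and each function is visited once even when it has multiple call sites.

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/dtensor/mlir/topological_iterator.h"

// Hypothetical helper: counts every op reachable from `main` in the
// iterator's topological order. Illustrative only.
int CountReachableOps(mlir::ModuleOp module) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_func) return 0;
  tensorflow::dtensor::TopologicalIterator iterator(main_func);
  int num_ops = 0;
  while (iterator.hasNext()) {
    mlir::Operation* op = iterator.next();
    if (op == nullptr) break;  // next() returns nullptr once exhausted.
    ++num_ops;
  }
  return num_ops;
}

SPMD expansion and sparse expansion (which now include topological_iterator.h) drive a loop like this one, expanding each op as it is returned.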
PiperOrigin-RevId: 551628382 --- tensorflow/python/BUILD | 14 +++++++------- tensorflow/python/__init__.py | 9 --------- tensorflow/python/modules_with_exports.py | 9 +++++++++ tensorflow/python/ops/BUILD | 9 +++++++++ tensorflow/python/ops/standard_ops.py | 9 +++++++++ tensorflow/python/tools/api/generator/doc_srcs.py | 10 +++++----- 6 files changed, 39 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index da9aa131bea93f..8a8a2cf2306342 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -80,13 +80,6 @@ py_library( "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:stateful_random_ops", "//tensorflow/python/ops/structured:structured_ops", - "//tensorflow/python/platform:app", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/platform:flags", - "//tensorflow/python/platform:gfile", - "//tensorflow/python/platform:resource_loader", - "//tensorflow/python/platform:sysconfig", - "//tensorflow/python/platform:tf_logging", "//tensorflow/python/tpu:tpu_estimator", ], ) @@ -409,6 +402,13 @@ py_library( "//tensorflow/python/ops/ragged:ragged_ops", "//tensorflow/python/ops/signal", "//tensorflow/python/ops/structured:structured_ops", + "//tensorflow/python/platform:app", + "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:flags", + "//tensorflow/python/platform:gfile", + "//tensorflow/python/platform:resource_loader", + "//tensorflow/python/platform:sysconfig", + "//tensorflow/python/platform:tf_logging", "//tensorflow/python/profiler", "//tensorflow/python/profiler:profiler_client", "//tensorflow/python/profiler:profiler_v2", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 467ea20e0b30d9..7c4fe150e6e6ba 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -47,15 +47,6 @@ # Sub-package for performing i/o directly instead of via ops in a graph. from tensorflow.python.lib.io import python_io -# Make some application and test modules available. -from tensorflow.python.platform import app -from tensorflow.python.platform import flags -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.platform import resource_loader -from tensorflow.python.platform import sysconfig as sysconfig_lib -from tensorflow.python.platform import test - from tensorflow.python.compat import v2_compat # Special dunders that we choose to export: diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index a20af60d562451..72a6fd136a301d 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -122,6 +122,15 @@ from tensorflow.python.ops.signal import signal from tensorflow.python.ops.structured import structured_ops as _structured_ops +# Platform +from tensorflow.python.platform import app +from tensorflow.python.platform import flags +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import sysconfig as sysconfig_lib +from tensorflow.python.platform import test + # Update the RaggedTensor package docs w/ a list of ops that support dispatch. 
ragged.__doc__ += ragged_ops.ragged_dispatch.ragged_op_list() diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index 4d4561dbf3473a..d27222a324747c 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -2864,8 +2864,17 @@ py_strict_library( "//tensorflow/python/ops/linalg:matmul_registrations", "//tensorflow/python/ops/linalg:solve_registrations", "//tensorflow/python/ops/parallel_for:control_flow_ops", + "//tensorflow/python/ops/ragged:ragged_batch_gather_ops", + "//tensorflow/python/ops/ragged:ragged_batch_gather_with_default_op", + "//tensorflow/python/ops/ragged:ragged_bincount_ops", + "//tensorflow/python/ops/ragged:ragged_check_ops", + "//tensorflow/python/ops/ragged:ragged_conversion_ops", "//tensorflow/python/ops/ragged:ragged_dispatch", + "//tensorflow/python/ops/ragged:ragged_embedding_ops", + "//tensorflow/python/ops/ragged:ragged_image_ops", "//tensorflow/python/ops/ragged:ragged_operators", + "//tensorflow/python/ops/ragged:ragged_squeeze_op", + "//tensorflow/python/ops/ragged:ragged_string_ops", ], ) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 05729d2f2b8288..9b301f06f767b9 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -78,8 +78,17 @@ from tensorflow.python.ops.parsing_ops import * from tensorflow.python.ops.partitioned_variables import * from tensorflow.python.ops.proto_ops import * +from tensorflow.python.ops.ragged import ragged_batch_gather_ops +from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op +from tensorflow.python.ops.ragged import ragged_bincount_ops +from tensorflow.python.ops.ragged import ragged_check_ops +from tensorflow.python.ops.ragged import ragged_conversion_ops from tensorflow.python.ops.ragged import ragged_dispatch as _ragged_dispatch +from tensorflow.python.ops.ragged import ragged_embedding_ops +from tensorflow.python.ops.ragged import ragged_image_ops from tensorflow.python.ops.ragged import ragged_operators as _ragged_operators +from tensorflow.python.ops.ragged import ragged_squeeze_op +from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.ops.random_ops import * from tensorflow.python.ops.script_ops import py_func from tensorflow.python.ops.session_ops import * diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py index ab3d6e78ad4bf6..43c468e98d6fe1 100644 --- a/tensorflow/python/tools/api/generator/doc_srcs.py +++ b/tensorflow/python/tools/api/generator/doc_srcs.py @@ -37,7 +37,7 @@ def __init__(self, docstring=None, docstring_module_name=None): _TENSORFLOW_DOC_SOURCES = { 'app': - DocSource(docstring_module_name='platform.app'), + DocSource(docstring='Import router for absl.app.'), 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), 'compat': @@ -52,7 +52,7 @@ def __init__(self, docstring=None, docstring_module_name=None): 'experimental.numpy': DocSource(docstring_module_name='ops.numpy_ops'), 'gfile': - DocSource(docstring_module_name='platform.gfile'), + DocSource(docstring='Import router for file_io.'), 'graph_util': DocSource(docstring_module_name='framework.graph_util'), 'image': @@ -80,7 +80,7 @@ def __init__(self, docstring=None, docstring_module_name=None): 'ragged': DocSource(docstring_module_name='ops.ragged'), 'resource_loader': - DocSource(docstring_module_name='platform.resource_loader'), + DocSource(docstring='Resource management library.'), 'sets': 
DocSource(docstring_module_name='ops.sets'), 'signal': @@ -92,9 +92,9 @@ def __init__(self, docstring=None, docstring_module_name=None): 'summary': DocSource(docstring_module_name='summary.summary'), 'sysconfig': - DocSource(docstring_module_name='platform.sysconfig'), + DocSource(docstring='System configuration library.'), 'test': - DocSource(docstring_module_name='platform.test'), + DocSource(docstring='Testing.'), 'train': DocSource( docstring=( 'Support for training models. See the' From b6946051acf429de68546763ab43e16368559442 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 13:44:42 -0700 Subject: [PATCH 269/410] Add type annotations to the return values of convert_to_tensor_v1* and convert_to_tensor_v2*. PiperOrigin-RevId: 551629182 --- .../python/framework/tensor_conversion.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/tensor_conversion.py b/tensorflow/python/framework/tensor_conversion.py index 25a89db9c80206..0098d6c3d06873 100644 --- a/tensorflow/python/framework/tensor_conversion.py +++ b/tensorflow/python/framework/tensor_conversion.py @@ -13,15 +13,22 @@ # limitations under the License. # ============================================================================== """Tensor conversion functions.""" +import typing + from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch from tensorflow.python.util import tf_export +if typing.TYPE_CHECKING: + # pylint: disable=g-bad-import-order + from tensorflow.python.framework import tensor as tensor_lib + # pylint: enable=g-bad-import-order + def convert_to_tensor_v1( value, dtype=None, name=None, preferred_dtype=None, dtype_hint=None -): +) -> "tensor_lib.Tensor": """Converts the given `value` to a `Tensor` (with the TF1 API).""" preferred_dtype = deprecation.deprecated_argument_lookup( "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype @@ -33,7 +40,7 @@ def convert_to_tensor_v1( @dispatch.add_dispatch_support def convert_to_tensor_v1_with_dispatch( value, dtype=None, name=None, preferred_dtype=None, dtype_hint=None -): +) -> "tensor_lib.Tensor": """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -96,7 +103,7 @@ def my_func(arg): @dispatch.add_dispatch_support def convert_to_tensor_v2_with_dispatch( value, dtype=None, dtype_hint=None, name=None -): +) -> "tensor_lib.Tensor": """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -162,7 +169,9 @@ def convert_to_tensor_v2_with_dispatch( ) -def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None): +def convert_to_tensor_v2( + value, dtype=None, dtype_hint=None, name=None +) -> "tensor_lib.Tensor": """Converts the given `value` to a `Tensor`.""" # preferred_dtype = preferred_dtype or dtype_hint return tensor_conversion_registry.convert( From f4d81e86534ea8a3f5ebed7c2a2b2107564c6bcc Mon Sep 17 00:00:00 2001 From: Armando Ugalde Velasco Date: Thu, 27 Jul 2023 13:53:21 -0700 Subject: [PATCH 270/410] Register Iterations automatically when reporting a time In MultipleIterationsAutoScaler, an Iteration had to be registered before starting to report processing or target processing times. However, there can be cases where the dispatcher is restarted but the Iteration is not registered again, causing the times' reporting to fail. 
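For illustration, a minimal sketch of that failure mode (the helper name, iteration id, and worker address are hypothetical; the ReportProcessingTime signature is the one shown in the auto_scaler.h diff below):

#include <string>

#include "absl/time/time.h"
#include "tensorflow/core/data/service/auto_scaler.h"

// Sketch: after a dispatcher restart, the in-memory per-iteration state is
// gone, so the first report for a still-running iteration uses an id that
// was never (re-)registered.
tsl::Status ReportAfterRestart(
    tensorflow::data::MultipleIterationsAutoScaler& auto_scaler) {
  // Old behavior: NotFound, because RegisterIteration(0) was never called
  // again after the restart. New behavior (this patch): OK, the iteration
  // is registered implicitly.
  return auto_scaler.ReportProcessingTime(
      /*iteration_id=*/0, /*worker_address=*/"/worker/task/0:20000",
      /*processing_time=*/absl::Microseconds(10));
}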
To solve this, we allow reporting times even if the specified Iteration has not been registered before. If it does not exist already, we create one. PiperOrigin-RevId: 551631559 --- tensorflow/core/data/service/auto_scaler.cc | 29 +++---- tensorflow/core/data/service/auto_scaler.h | 27 +++--- .../core/data/service/auto_scaler_test.cc | 85 ++++--------------- .../core/data/service/dispatcher_impl.cc | 5 -- 4 files changed, 40 insertions(+), 106 deletions(-) diff --git a/tensorflow/core/data/service/auto_scaler.cc b/tensorflow/core/data/service/auto_scaler.cc index 8b8ad35145c1f3..5a0e2d497cdddf 100644 --- a/tensorflow/core/data/service/auto_scaler.cc +++ b/tensorflow/core/data/service/auto_scaler.cc @@ -186,14 +186,11 @@ tsl::Status AutoScaler::RemoveConsumer(int64_t consumer_id) return tsl::OkStatus(); } -tsl::Status MultipleIterationsAutoScaler::RegisterIteration( - int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_) { - tsl::mutex_lock l(mu_); - if (auto_scalers_.contains(iteration_id)) - return absl::AlreadyExistsError(absl::StrCat( - "AutoScaler for iteration_id ", iteration_id, " already exists")); - auto_scalers_[iteration_id] = std::make_unique(); - return tsl::OkStatus(); +void MultipleIterationsAutoScaler::EnsureIterationIsRegistered( + int64_t iteration_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!auto_scalers_.contains(iteration_id)) { + auto_scalers_[iteration_id] = std::make_unique(); + } } tsl::Status MultipleIterationsAutoScaler::UnregisterIteration( @@ -248,10 +245,8 @@ std::optional MultipleIterationsAutoScaler::GetOptimalNumberOfWorkers() tsl::Status MultipleIterationsAutoScaler::ReportProcessingTime( int64_t iteration_id, const std::string& worker_address, absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_) { - tsl::tf_shared_lock l(mu_); - if (!auto_scalers_.contains(iteration_id)) - return absl::NotFoundError(absl::StrCat( - "Could not find AutoScaler for iteration_id ", iteration_id)); + tsl::mutex_lock l(mu_); + EnsureIterationIsRegistered(iteration_id); tsl::Status status = auto_scalers_[iteration_id]->ReportProcessingTime( worker_address, processing_time); @@ -261,10 +256,8 @@ tsl::Status MultipleIterationsAutoScaler::ReportProcessingTime( tsl::Status MultipleIterationsAutoScaler::ReportTargetProcessingTime( int64_t iteration_id, int64_t consumer_id, absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_) { - tsl::tf_shared_lock l(mu_); - if (!auto_scalers_.contains(iteration_id)) - return absl::NotFoundError(absl::StrCat( - "Could not find AutoScaler for iteration_id ", iteration_id)); + tsl::mutex_lock l(mu_); + EnsureIterationIsRegistered(iteration_id); tsl::Status status = auto_scalers_[iteration_id]->ReportTargetProcessingTime( consumer_id, target_processing_time); @@ -277,7 +270,7 @@ tsl::Status MultipleIterationsAutoScaler::RemoveWorker( tsl::tf_shared_lock l(mu_); if (!auto_scalers_.contains(iteration_id)) return absl::NotFoundError(absl::StrCat( - "Could not find AutoScaler for iteration_id ", iteration_id)); + "There are no reported times for iteration_id ", iteration_id)); tsl::Status status = auto_scalers_[iteration_id]->RemoveWorker(worker_address); @@ -290,7 +283,7 @@ tsl::Status MultipleIterationsAutoScaler::RemoveConsumer(int64_t iteration_id, tsl::tf_shared_lock l(mu_); if (!auto_scalers_.contains(iteration_id)) return absl::NotFoundError(absl::StrCat( - "Could not find AutoScaler for iteration_id ", iteration_id)); + "There are no reported times for iteration_id ", iteration_id)); tsl::Status status = 
auto_scalers_[iteration_id]->RemoveConsumer(consumer_id); return status; diff --git a/tensorflow/core/data/service/auto_scaler.h b/tensorflow/core/data/service/auto_scaler.h index ba7d2d019e8cdc..ee68d2226a6baf 100644 --- a/tensorflow/core/data/service/auto_scaler.h +++ b/tensorflow/core/data/service/auto_scaler.h @@ -115,10 +115,6 @@ class AutoScaler { class MultipleIterationsAutoScaler { public: MultipleIterationsAutoScaler() = default; - // Registers iteration with `iteration_id`, allowing its future reported times - // to be considered for the current workload estimation. Returns an error if - // the specified iteration already exists. - tsl::Status RegisterIteration(int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_); // Unregisters iteration with `iteration_id`, removing its reported // times from consideration of the current workload estimation. // Returns an error if the specified iteration does not exist. @@ -134,37 +130,40 @@ class MultipleIterationsAutoScaler { TF_LOCKS_EXCLUDED(mu_); // Reports the latest observed processing time from the worker with // `worker_address` for iteration with `iteration_id`. Returns an error if - // the specified iteration was not previously registered, or `processing_time` - // is ZeroDuration or negative. + // `processing_time` is ZeroDuration or negative. tsl::Status ReportProcessingTime(int64_t iteration_id, const std::string& worker_address, absl::Duration processing_time) TF_LOCKS_EXCLUDED(mu_); // Reports the latest observed target processing time from the consumer // identified by `consumer_id` for iteration with `iteration_id`. Returns an - // error if the specified iteration was not previously registered, or - // `target_processing_time` is ZeroDuration or negative. + // error if `target_processing_time` is ZeroDuration or negative. tsl::Status ReportTargetProcessingTime(int64_t iteration_id, int64_t consumer_id, absl::Duration target_processing_time) TF_LOCKS_EXCLUDED(mu_); // Unregisters the worker with `worker_address` for iteration with // `iteration_id`, removing its reported processing time from consideration of - // the current workload estimation. Returns an error if iteration with - // `iteration_id` was not previously registered, or the specified worker does - // not exist. + // the current workload estimation. Returns an error if there are no + // previously reported processing times for iteration with `iteration_id` and + // the specified worker. tsl::Status RemoveWorker(int64_t iteration_id, const std::string& worker_address) TF_LOCKS_EXCLUDED(mu_); // Unregisters the consumer identified by `consumer_id` for iteration with // `iteration_id`, removing its reported target processing time from - // consideration of the current workload estimation. Returns an error if - // iteration with `iteration_id` was not previously registered, or the - // specified consumer does not exist. + // consideration of the current workload estimation. Returns an error if there + // are no previously reported processing times for iteration with + // `iteration_id` and the specified consumer. tsl::Status RemoveConsumer(int64_t iteration_id, int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_); private: + // Registers iteration with `iteration_id` if it does not exist already, + // allowing its future reported times to be considered for the current + // workload estimation. + void EnsureIterationIsRegistered(int64_t iteration_id) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable tsl::mutex mu_; // Map from iteration id to AutoScaler. 
absl::flat_hash_map> auto_scalers_ diff --git a/tensorflow/core/data/service/auto_scaler_test.cc b/tensorflow/core/data/service/auto_scaler_test.cc index 449a35b09ecdc5..eb3a88c2ad26d0 100644 --- a/tensorflow/core/data/service/auto_scaler_test.cc +++ b/tensorflow/core/data/service/auto_scaler_test.cc @@ -308,22 +308,10 @@ TEST(AutoScalerTest, RemoveConsumerAfterNewTPTReported) { TF_ASSERT_OK(auto_scaler.RemoveConsumer(0)); } -TEST(MultipleIterationsAutoScalerTest, RegisterNewIteration) { - MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); -} - -TEST(MultipleIterationsAutoScalerTest, RegisterExistingIteration) { - MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - EXPECT_THAT(auto_scaler.RegisterIteration(0), - StatusIs(absl::StatusCode::kAlreadyExists)); -} - TEST(MultipleIterationsAutoScalerTest, UnregisterExistingIteration) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); + TF_ASSERT_OK( + auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(5))); TF_ASSERT_OK(auto_scaler.UnregisterIteration(0)); } @@ -343,8 +331,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, UpdateOptimalNumberOfWorkersMetricNoReportedPTs) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(5))); @@ -357,8 +343,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, UpdateOptimalNumberOfWorkersMetricNoReportedTPTs) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -371,8 +355,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, UpdateOptimalNumberOfWorkersMetricWithReportedTimes) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(5))); @@ -391,16 +373,12 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, GetOptimalNumberOfWorkersInitialState) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); EXPECT_EQ(auto_scaler.GetOptimalNumberOfWorkers(), std::nullopt); } TEST(MultipleIterationsAutoScalerTest, GetOptimalNumberOfWorkersNoRegisteredWorkers) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(5))); @@ -412,8 +390,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, GetOptimalNumberOfWorkersNoRegisteredConsumers) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -425,8 +401,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, GetOptimalNumberOfWorkersExpectedEstimate1) { 
MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); // Estimated number of workers for iteration 0 = 8 TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", @@ -451,9 +425,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, GetOptimalNumberOfWorkersExpectedEstimate2) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(2)); // Estimated number of workers for iteration 0 = 8 TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", @@ -485,29 +456,23 @@ TEST(MultipleIterationsAutoScalerTest, EXPECT_EQ(auto_scaler.GetOptimalNumberOfWorkers(), 20); } -TEST(MultipleIterationsAutoScalerTest, - ReportProcessingTimeUnregisteredIteration) { +TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNewIteration) { MultipleIterationsAutoScaler auto_scaler; - EXPECT_THAT(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", - absl::Microseconds(10)), - StatusIs(absl::StatusCode::kNotFound)); + TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", + absl::Microseconds(10))); } TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNewWorker) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); - TF_ASSERT_OK(auto_scaler.ReportProcessingTime(1, "/worker/task/0:20000", + TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/1:20000", absl::Microseconds(10))); } TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeExistingWorker) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -521,8 +486,6 @@ TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeExistingWorker) { TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNewAndExisting) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -545,7 +508,6 @@ TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNewAndExisting) { TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeZeroDuration) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); tsl::Status result = auto_scaler.ReportProcessingTime( 0, "/worker/task/0:20000", absl::ZeroDuration()); @@ -554,37 +516,30 @@ TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeZeroDuration) { TEST(MultipleIterationsAutoScalerTest, ReportProcessingTimeNegativeDuration) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); tsl::Status result = auto_scaler.ReportProcessingTime( 0, "/worker/task/0:20000", absl::Microseconds(-10)); EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(MultipleIterationsAutoScalerTest, - ReportTargetProcessingTimeUnregisteredIteration) { +TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeNewIteration) { MultipleIterationsAutoScaler auto_scaler; - 
EXPECT_THAT( - auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10)), - StatusIs(absl::StatusCode::kNotFound)); + TF_ASSERT_OK( + auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); } TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeNewConsumer) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); TF_ASSERT_OK( - auto_scaler.ReportTargetProcessingTime(1, 0, absl::Microseconds(10))); + auto_scaler.ReportTargetProcessingTime(0, 1, absl::Microseconds(10))); } TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeExistingWorker) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); @@ -599,8 +554,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeNewAndExisting) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); @@ -623,7 +576,6 @@ TEST(MultipleIterationsAutoScalerTest, TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeZeroDuration) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); tsl::Status result = auto_scaler.ReportTargetProcessingTime(0, 0, absl::ZeroDuration()); @@ -633,7 +585,6 @@ TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeZeroDuration) { TEST(MultipleIterationsAutoScalerTest, ReportTargetProcessingTimeNegativeDuration) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); tsl::Status result = auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(-10)); @@ -650,8 +601,6 @@ TEST(MultipleIterationsAutoScalerTest, RemoveWorkerUnregisteredIteration) { TEST(MultipleIterationsAutoScalerTest, RemoveWorkerSuccessful) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -663,14 +612,14 @@ TEST(MultipleIterationsAutoScalerTest, RemoveWorkerSuccessful) { TEST(MultipleIterationsAutoScalerTest, RemoveNonexistentWorker) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - EXPECT_THAT(auto_scaler.RemoveWorker(0, "/worker/task/0:20000"), + TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", + absl::Microseconds(10))); + EXPECT_THAT(auto_scaler.RemoveWorker(0, "/worker/task/1:20000"), StatusIs(absl::StatusCode::kNotFound)); } TEST(MultipleIterationsAutoScalerTest, RemoveWorkerAfterNewPTReported) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); TF_ASSERT_OK(auto_scaler.ReportProcessingTime(0, "/worker/task/0:20000", absl::Microseconds(10))); @@ -689,8 +638,6 @@ TEST(MultipleIterationsAutoScalerTest, RemoveConsumerUnregisteredIteration) { TEST(MultipleIterationsAutoScalerTest, RemoveConsumerSuccessful) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - TF_ASSERT_OK(auto_scaler.RegisterIteration(1)); 
TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); @@ -702,14 +649,14 @@ TEST(MultipleIterationsAutoScalerTest, RemoveConsumerSuccessful) { TEST(MultipleIterationsAutoScalerTest, RemoveNonexistentConsumer) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); - EXPECT_THAT(auto_scaler.RemoveConsumer(0, 0), + TF_ASSERT_OK( + auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); + EXPECT_THAT(auto_scaler.RemoveConsumer(0, 1), StatusIs(absl::StatusCode::kNotFound)); } TEST(MultipleIterationsAutoScalerTest, RemoveConsumerAfterNewTPTReported) { MultipleIterationsAutoScaler auto_scaler; - TF_ASSERT_OK(auto_scaler.RegisterIteration(0)); TF_ASSERT_OK( auto_scaler.ReportTargetProcessingTime(0, 0, absl::Microseconds(10))); diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 96abf4072ddced..326ed1ee8095d8 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -836,11 +836,6 @@ Status DataServiceDispatcherImpl::CreateIteration( TF_RETURN_IF_ERROR(Apply(update)); TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration)); - Status auto_scaler_status = auto_scaler_.RegisterIteration(iteration_id); - if (!auto_scaler_status.ok()) { - LOG(WARNING) << "Failed to register Iteration " << iteration_id - << " with tf.data service AutoScaler: " << auto_scaler_status; - } return OkStatus(); } From 766d99388c8a2f56d8a034442053c5de0b36b4c4 Mon Sep 17 00:00:00 2001 From: Cesar Magana De Leon Date: Thu, 27 Jul 2023 14:02:40 -0700 Subject: [PATCH 271/410] Implementing BEF Generation to saved_model_aot_compile and minimal changes to python API path. PiperOrigin-RevId: 551634236 --- tensorflow/core/tfrt/saved_model/BUILD | 11 + tensorflow/core/tfrt/saved_model/python/BUILD | 24 +- .../python/saved_model_aot_compile_test.py | 37 ---- .../core/tfrt/saved_model/saved_model.cc | 205 +++++++++--------- .../core/tfrt/saved_model/saved_model.h | 29 +++ .../saved_model/saved_model_aot_compile.cc | 120 +++++++++- .../saved_model/saved_model_aot_compile.h | 6 +- tensorflow/core/tfrt/saved_model/utils/BUILD | 2 + 8 files changed, 267 insertions(+), 167 deletions(-) delete mode 100644 tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index a157756f468eee..1ed76db746ff61 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -33,8 +33,10 @@ cc_library( ], hdrs = ["saved_model_aot_compile.h"], deps = [ + ":saved_model", "//tensorflow/cc/saved_model:constants", "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/mlir/tfrt:import_model", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", @@ -43,12 +45,20 @@ cc_library( "//tensorflow/core/platform:path", "//tensorflow/core/platform:status", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat", + "//tensorflow/core/tfrt/fallback:fallback_state", + "//tensorflow/core/tfrt/graph_executor", + "//tensorflow/core/tfrt/graph_executor:export_mlir", "//tensorflow/core/tfrt/graph_executor:graph_execution_options", "//tensorflow/core/tfrt/runtime", + "//tensorflow/core/tfrt/saved_model/utils:serialize_bef_utils", + "//tensorflow/core/tfrt/utils", "//tensorflow/tsl/platform:env", 
"//tensorflow/tsl/platform:status", "@com_google_absl//absl/status", + "@tf_runtime//:bef", + "@tf_runtime//:befexecutor", "@tf_runtime//:core_runtime_alwayslink", + "@tf_runtime//:hostcontext", ], ) @@ -94,6 +104,7 @@ cc_library( "//tensorflow/core/tfrt/mlrt/kernel:batch_kernel", "//tensorflow/core/tfrt/runtime", "//tensorflow/core/tfrt/runtime:work_queue_interface", + "//tensorflow/core/tfrt/saved_model/utils:serialize_bef_utils", "//tensorflow/core/tfrt/utils", "//tensorflow/core/tfrt/utils:error_util", "//tensorflow/core/tfrt/utils:fallback_tensor", diff --git a/tensorflow/core/tfrt/saved_model/python/BUILD b/tensorflow/core/tfrt/saved_model/python/BUILD index 233cadd38c35a6..a6fc120e28920e 100644 --- a/tensorflow/core/tfrt/saved_model/python/BUILD +++ b/tensorflow/core/tfrt/saved_model/python/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "tf_python_pybind_extension") -load("//tensorflow:pytype.default.bzl", "pytype_strict_binary", "pytype_strict_contrib_test") +load("//tensorflow:pytype.default.bzl", "pytype_strict_binary") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,6 +13,7 @@ package_group( # Authorized users go here. "//tensorflow/core/tfrt/saved_model/...", "//tensorflow/core/tfrt/graph_executor/...", + "//learning/brain/tfrt/cpp_tests/gpu_inference/...", ], ) @@ -79,24 +80,3 @@ tf_python_pybind_extension( "@pybind11_abseil//pybind11_abseil:status_casters", ], ) - -# copybara:uncomment_begin(AoT) -# pytype_strict_contrib_test( -# name = "saved_model_aot_compile_test", -# size = "small", -# srcs = [ -# "saved_model_aot_compile_test.py", -# ], -# data = [ -# "//learning/brain/tfrt/cpp_tests/gpu_inference:testdata", -# ], -# python_version = "PY3", -# deps = [ -# ":_pywrap_saved_model_aot_compile", -# "//base/python:pywrapbase", -# "//testing/pybase", -# "//third_party/py/lingvo:compat", -# "//tensorflow/python/platform:client_testlib", -# ], -# ) -# copybara:uncomment_end diff --git a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py b/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py deleted file mode 100644 index 2e849fb743855c..00000000000000 --- a/tensorflow/core/tfrt/saved_model/python/saved_model_aot_compile_test.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -import os - -import lingvo.compat as tf - -from tensorflow.core.tfrt.saved_model.python import _pywrap_saved_model_aot_compile -from tensorflow.python.platform import test - - -class SavedModelAotCompileTest(test.TestCase): - - def testVerify_saved_model(self): - outputpath = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR") - filepath = "learning/brain/tfrt/cpp_tests/gpu_inference/test_data/translate_converted_placed/" - _pywrap_saved_model_aot_compile.AotCompileSavedModel( - filepath, _pywrap_saved_model_aot_compile.AotOptions(), outputpath - ) - - # Verifies that .pbtxt is created correctly in the output directory - self.assertTrue(tf.io.gfile.exists(outputpath + "/aot_saved_model.pbtxt")) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 3507210f6e6d40..97d7d1c75e2803 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/saved_model/saved_model_import_input.h" +#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h" #include "tensorflow/core/tfrt/utils/error_util.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tensorflow/core/tfrt/utils/utils.h" @@ -87,20 +88,6 @@ namespace { constexpr absl::string_view kSignatureJoiningDelimiter = "+"; -using SignatureMap = absl::flat_hash_map; -using ::tensorflow::StatusOr; - -struct Initializer { - std::string name; -}; - -struct InitializersAndSignatures { - // Initializers are kept in a certain order as they need to be executed in - // that order. - std::vector initializers; - SignatureMap signature_map; -}; - auto* saved_model_read_meta_graph_time_seconds = tensorflow::monitoring::Gauge::New( "/tensorflow/tfrt/saved_model/read_meta_graph_time", @@ -143,50 +130,6 @@ auto* saved_model_input_spec_validation_failure = "/tensorflow/tfrt/saved_model/input_spec_validation_failure", "Record the models that failed input spec validation.", "model_name"); -StatusOr GetInitializersAndSignatures( - mlir::ModuleOp module) { - InitializersAndSignatures result; - - // Create placeholders for initializers. 
- for (auto session_initializer_name : - mlir::tf_saved_model::GetSessionInitializerExportedName(module)) { - Initializer initializer; - initializer.name = session_initializer_name.str(); - result.initializers.push_back(std::move(initializer)); - } - - auto& signatures = result.signature_map; - TF_RETURN_IF_ERROR(tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR( - module, - [&signatures](const tensorflow::TFRTSavedModelSignatureInfo& sig_info) { - auto signature_name = std::string(sig_info.func_name); - auto& signature = signatures[signature_name]; - - auto copy = [](llvm::ArrayRef src, - std::vector* dst) { - transform(src, std::back_inserter(*dst), - [](llvm::StringRef x) { return x.str(); }); - }; - copy(sig_info.input_names, &signature.input_names); - copy(sig_info.output_names, &signature.output_names); - copy(sig_info.input_devices, &signature.input_devices); - - DCHECK(signature.input_specs.empty()); - signature.input_specs.reserve(sig_info.input_specs.size()); - for (auto& spec : sig_info.input_specs) { - signature.input_specs.push_back(TensorSpec(spec.first, spec.second)); - } - - DCHECK(signature.output_specs.empty()); - signature.output_specs.reserve(sig_info.output_specs.size()); - for (auto& spec : sig_info.output_specs) { - signature.output_specs.push_back(TensorSpec(spec.first, spec.second)); - } - })); - - return result; -} - tensorflow::Status RunBytecodeInitializers( const GraphExecutionOptions& options, const InitializersAndSignatures& initializers_and_signatures, @@ -316,52 +259,6 @@ std::vector FindNamesForValidSignatures( return valid_signature_names; } -StatusOr> ImportSavedModel( - mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, - const FallbackState& fallback_state, std::string saved_model_dir, - bool import_user_signatures, bool run_placer_grappler_on_functions) { - std::vector signature_names; - if (import_user_signatures) { - signature_names = FindNamesForValidSignatures(meta_graph_def); - if (signature_names.empty()) - LOG(WARNING) << "No valid signature found for model: " << saved_model_dir; - } - - // TfrtSavedModelMLIRImportInput basically implements the graph processing - // logic (eg. Placer and Grappler) used in DirectSession, which apply graph - // transformations on each subgraphs (ie. signatures). It is reusing the - // code path in DirectSession to avoid problems caused by different behavior - // in a different code path. And it is injected to the MLIR importer so that - // the importer can import the transformed graph instead of the original - // graph. 
- TF_ASSIGN_OR_RETURN(auto import_input, - TfrtSavedModelMLIRImportInput::Create( - fallback_state, &meta_graph_def, /*debug_info=*/{}, - run_placer_grappler_on_functions)); - - TF_ASSIGN_OR_RETURN( - auto module, - tensorflow::ConvertSavedModelV1ToMlirLite( - import_input, - /*exported_names=*/absl::MakeSpan(signature_names), context)); - - LOG(INFO) << "TFRT ImportSavedModel: Functionalization took " - << absl::ToInt64Milliseconds( - import_input.GetFunctionalizationDuration()) - << " ms."; - LOG(INFO) << "TFRT ImportSavedModel: Grappler took " - << absl::ToInt64Milliseconds(import_input.GetGrapplerDuration()) - << " ms."; - - saved_model_functionalization_time_seconds->GetCell(saved_model_dir) - ->Set(absl::ToInt64Seconds(import_input.GetFunctionalizationDuration())); - - saved_model_grappler_time_seconds->GetCell(saved_model_dir) - ->Set(absl::ToInt64Seconds(import_input.GetGrapplerDuration())); - - return module; -} - tensorflow::Status IsInputSpecsCorrect( absl::string_view name, const internal::Signature& signature, absl::Span inputs) { @@ -541,6 +438,98 @@ void UpdateCompileOptions(SavedModel::Options& options) { } } +} // namespace + +StatusOr> ImportSavedModel( + mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, + const FallbackState& fallback_state, std::string saved_model_dir, + bool import_user_signatures, bool run_placer_grappler_on_functions) { + std::vector signature_names; + if (import_user_signatures) { + signature_names = FindNamesForValidSignatures(meta_graph_def); + if (signature_names.empty()) + LOG(WARNING) << "No valid signature found for model: " << saved_model_dir; + } + + // TfrtSavedModelMLIRImportInput basically implements the graph processing + // logic (eg. Placer and Grappler) used in DirectSession, which apply graph + // transformations on each subgraphs (ie. signatures). It is reusing the + // code path in DirectSession to avoid problems caused by different behavior + // in a different code path. And it is injected to the MLIR importer so that + // the importer can import the transformed graph instead of the original + // graph. + TF_ASSIGN_OR_RETURN(auto import_input, + TfrtSavedModelMLIRImportInput::Create( + fallback_state, &meta_graph_def, /*debug_info=*/{}, + run_placer_grappler_on_functions)); + + TF_ASSIGN_OR_RETURN( + auto module, + tensorflow::ConvertSavedModelV1ToMlirLite( + import_input, + /*exported_names=*/absl::MakeSpan(signature_names), context)); + + LOG(INFO) << "TFRT ImportSavedModel: Functionalization took " + << absl::ToInt64Milliseconds( + import_input.GetFunctionalizationDuration()) + << " ms."; + LOG(INFO) << "TFRT ImportSavedModel: Grappler took " + << absl::ToInt64Milliseconds(import_input.GetGrapplerDuration()) + << " ms."; + + saved_model_functionalization_time_seconds->GetCell(saved_model_dir) + ->Set(absl::ToInt64Seconds(import_input.GetFunctionalizationDuration())); + + saved_model_grappler_time_seconds->GetCell(saved_model_dir) + ->Set(absl::ToInt64Seconds(import_input.GetGrapplerDuration())); + + return module; +} + +StatusOr GetInitializersAndSignatures( + mlir::ModuleOp module) { + InitializersAndSignatures result; + + // Create placeholders for initializers. 
+ for (auto session_initializer_name : + mlir::tf_saved_model::GetSessionInitializerExportedName(module)) { + Initializer initializer; + initializer.name = session_initializer_name.str(); + result.initializers.push_back(std::move(initializer)); + } + + auto& signatures = result.signature_map; + TF_RETURN_IF_ERROR(tensorflow::MapFunctionSignaturesFromTFSavedModelMLIR( + module, + [&signatures](const tensorflow::TFRTSavedModelSignatureInfo& sig_info) { + auto signature_name = std::string(sig_info.func_name); + auto& signature = signatures[signature_name]; + + auto copy = [](llvm::ArrayRef src, + std::vector* dst) { + transform(src, std::back_inserter(*dst), + [](llvm::StringRef x) { return x.str(); }); + }; + copy(sig_info.input_names, &signature.input_names); + copy(sig_info.output_names, &signature.output_names); + copy(sig_info.input_devices, &signature.input_devices); + + DCHECK(signature.input_specs.empty()); + signature.input_specs.reserve(sig_info.input_specs.size()); + for (auto& spec : sig_info.input_specs) { + signature.input_specs.push_back(TensorSpec(spec.first, spec.second)); + } + + DCHECK(signature.output_specs.empty()); + signature.output_specs.reserve(sig_info.output_specs.size()); + for (auto& spec : sig_info.output_specs) { + signature.output_specs.push_back(TensorSpec(spec.first, spec.second)); + } + })); + + return result; +} + StatusOr ReadSavedModel( absl::string_view saved_model_dir, const std::unordered_set& tags) { @@ -560,7 +549,6 @@ StatusOr ReadSavedModel( return std::move(meta_graph_def); } -} // namespace tensorflow::StatusOr> SavedModelImpl::LoadSavedModel(Options options, @@ -699,6 +687,13 @@ SavedModelImpl::LoadSavedModel(Options options, graph_executor->kernel_registry()); } else { DCHECK(!bef.empty()); + // TODO(cesarmagana) + // Call code if bef exists, make into its own util + // Deserialization is only called if BEF is found + + // and if bef file exists this will be called + // Create another function where we first detect if bef_file exists in + // saved_model dir then we run code below if not we call original code. ASSIGN_OR_RETURN_IN_INIT( bef_file, tfrt::CreateBefFileFromBefBuffer( *options.graph_execution_options.runtime, bef)); diff --git a/tensorflow/core/tfrt/saved_model/saved_model.h b/tensorflow/core/tfrt/saved_model/saved_model.h index 256706d0129224..21fcbd1dce432d 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.h +++ b/tensorflow/core/tfrt/saved_model/saved_model.h @@ -202,6 +202,35 @@ class SavedModel { const Options options_; }; +// TODO(cesarmagana) Create new library saved_model_utils and move (refactor) +// functions to the anonymous space of the util file. Making only one API public +// for use in both LoadSavedModel and AotCompileSavedModel. +StatusOr> ImportSavedModel( + mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, + const FallbackState& fallback_state, std::string saved_model_dir, + bool import_user_signatures, bool run_placer_grappler_on_functions); + +StatusOr ReadSavedModel( + absl::string_view saved_model_dir, + const std::unordered_set& tags); + +using SignatureMap = absl::flat_hash_map; +using ::tensorflow::StatusOr; + +struct Initializer { + std::string name; +}; + +struct InitializersAndSignatures { + // Initializers are kept in a certain order as they need to be executed in + // that order. 
+ std::vector initializers; + SignatureMap signature_map; +}; + +StatusOr GetInitializersAndSignatures( + mlir::ModuleOp module); + class SavedModelImpl final : public SavedModel { public: struct JoinedSignature; diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc index aa5fcd51e98195..e9bc5a2a889b1d 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.cc @@ -16,29 +16,145 @@ limitations under the License. #include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h" #include +#include #include +#include #include "absl/status/status.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/export_mlir.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/graph_executor/graph_executor.h" #include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/saved_model/saved_model.h" +#include "tensorflow/core/tfrt/saved_model/utils/serialize_bef_utils.h" +#include "tensorflow/core/tfrt/utils/utils.h" #include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/file_system_helper.h" #include "tensorflow/tsl/platform/status.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime namespace tensorflow::tfrt_stub { +void UpdateCompileOptions(AotOptions& options) { + // Disable DecomposeResourceOpsPass for now, as DecomposeResourceGather does + // not work well with GPU (b/232819415). + if (options.graph_execution_options->enable_tfrt_gpu) { + options.graph_execution_options->compile_options.decompose_resource_ops = + false; + } + + options.graph_execution_options->compile_options + .fuse_get_resource_ops_in_hoisting = + !options.graph_execution_options->enable_mlrt; +} + AotOptions::AotOptions() : graph_execution_options(nullptr) {} Status AotCompileSavedModel(absl::string_view input_model_dir, - const AotOptions& aot_options, + AotOptions aot_options, absl::string_view output_model_dir) { + if (aot_options.graph_execution_options == nullptr) { + // Since we are not going to actually run the model during AoT + // compilation and optimization, we choose a value of 4 inter_op_threads + // which is commonly used for testing. 
+ SetGlobalRuntime(tfrt_stub::Runtime::Create(/*num_inter_op_threads=*/4)); + + GraphExecutionOptions graph_execution_options(GetGlobalRuntime()); + + graph_execution_options.enable_tfrt_gpu = true; + graph_execution_options.enable_grappler_function_optimizer = true; + graph_execution_options.compile_options.enable_grappler = true; + graph_execution_options.compile_options.device_target = + TfrtDeviceInfraTarget::kGpu; + graph_execution_options.compile_options.hoist_invariant_ops = true; + + aot_options.graph_execution_options = + std::make_shared(graph_execution_options); + } + + if (aot_options.tags.empty()) { + aot_options.tags = {"serve", "gpu"}; + } + // TODO(cesarmagana) Refactor duplicated code from saved_model and + // saved_model_aot_compile into a shared util library for LoadSavedModel() + TF_ASSIGN_OR_RETURN(tensorflow::MetaGraphDef meta_graph_def, + ReadSavedModel(input_model_dir, aot_options.tags)); + + UpdateTpuTargetByBridgeCompatibility(*aot_options.graph_execution_options, + meta_graph_def.graph_def()); + UpdateCompileOptions(aot_options); + aot_options.graph_execution_options->compile_options.saved_model_dir = + input_model_dir; + mlir::DialectRegistry registry; + RegisterMlirDialect(registry); + mlir::MLIRContext context(registry); + + tensorflow::SessionOptions session_options = + CreateDefaultSessionOptions(*aot_options.graph_execution_options); + session_options.config.mutable_experimental()->set_optimize_for_static_graph( + true); + LOG_FIRST_N(INFO, 10) << "SessionOptions: " + << session_options.config.DebugString(); + LOG_FIRST_N(INFO, 10) << "GraphExecutionOptions: " + << *aot_options.graph_execution_options; + + const ::tensorflow::FunctionDefLibrary& fdef_lib = + meta_graph_def.graph_def().library(); + ASSIGN_OR_RETURN_IN_IMPORT( + std::unique_ptr fallback_state, + FallbackState::Create(session_options, fdef_lib)); + ASSIGN_OR_RETURN_IN_IMPORT( + mlir::OwningOpRef mlir_module, + ImportSavedModel(&context, meta_graph_def, *fallback_state, + std::string(input_model_dir), + /*import_user_signatures=*/true, + aot_options.graph_execution_options + ->run_placer_grappler_on_functions)); + + auto kernel_registry = std::make_unique(); + + auto resource_context = std::make_unique(); + ModelRuntimeContext model_context(&*aot_options.graph_execution_options, + std::string(input_model_dir), + resource_context.get()); + + { + model_context.set_meta_graph_def(&meta_graph_def); + TF_RETURN_IF_ERROR( + aot_options.graph_execution_options->runtime->CreateRuntimeResources( + model_context)); + + model_context.set_meta_graph_def(nullptr); + } + + ASSIGN_OR_RETURN_WITH_STAGE_INFO( + "graph_executor creation", + std::unique_ptr graph_executor, + GraphExecutor::Create(*aot_options.graph_execution_options, + *fallback_state, std::move(resource_context), + std::move(*meta_graph_def.mutable_graph_def()), + std::move(kernel_registry))); + + tfrt::BefBuffer bef; + RETURN_IF_ERROR_IN_COMPILE(tensorflow::ConvertTfMlirToBef( + aot_options.graph_execution_options->compile_options, mlir_module.get(), + &bef, model_context, fallback_state.get())); + if (bef.empty()) { + LOG(ERROR) << "BefBuffer is empty."; + return absl::InternalError("BefBuffer is empty."); + } + Env* env = Env::Default(); const std::string warmup_requests_path = io::JoinPath( input_model_dir, "assets.extra", "tf_serving_warmup_requests"); @@ -71,11 +187,13 @@ Status AotCompileSavedModel(absl::string_view input_model_dir, const std::string output_file_directory = io::JoinPath(std::string(output_model_dir), absl::StrCat("aot_", 
kSavedModelFilenamePb)); + // serialize bef to a file under output_model_dir return env->CopyFile(saved_model_pb_path, output_file_directory); } else { const std::string output_file_directory = io::JoinPath(std::string(output_model_dir), absl::StrCat("aot_", kSavedModelFilenamePbTxt)); + // serialize bef to a file under output_model_dir return env->CopyFile(saved_model_pbtxt_path, output_file_directory); } } diff --git a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h index 5547d3506e0ace..017bc7923bc42f 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h +++ b/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" @@ -25,8 +26,9 @@ limitations under the License. namespace tensorflow::tfrt_stub { struct AotOptions { AotOptions(); + std::unordered_set tags = {}; - std::unique_ptr graph_execution_options; + std::shared_ptr graph_execution_options; }; // AOT Compiles saved_model in input_model_dir, writing output @@ -34,7 +36,7 @@ struct AotOptions { // "{input_model_dir}/aot_packages" if output dir provided. Warmup requests // should be present in input_model_dir Status AotCompileSavedModel(absl::string_view input_model_dir, - const AotOptions& aot_options = {}, + AotOptions aot_options = {}, absl::string_view output_model_dir = ""); } // namespace tensorflow::tfrt_stub diff --git a/tensorflow/core/tfrt/saved_model/utils/BUILD b/tensorflow/core/tfrt/saved_model/utils/BUILD index b67ed5caf68563..4e0ade80568b74 100644 --- a/tensorflow/core/tfrt/saved_model/utils/BUILD +++ b/tensorflow/core/tfrt/saved_model/utils/BUILD @@ -10,6 +10,8 @@ package_group( name = "friends", packages = [ # Authorized users go here. + "//tensorflow/core/tfrt/saved_model/...", + "//learning/brain/tfrt/cpp_tests/gpu_inference/...", ], ) From 7983c59f617d9e4698fd571d6ab62f8f0e95d6cf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 14:20:01 -0700 Subject: [PATCH 272/410] Don't add input shardings for conditional statements when generating replicated strategies. This is consistent with the generation of other kinds of shardings for conditional statements. 
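A usage note on the AotCompileSavedModel / AotOptions API added in saved_model_aot_compile.h above: when graph_execution_options is left null and tags is left empty, the function now installs a global runtime and fills in GPU-oriented defaults itself (tags {"serve", "gpu"}, grappler, op hoisting, and the GPU device target). The sketch below is a minimal, hypothetical driver, not part of the patch; the model paths are placeholders.

#include <utility>

#include "tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h"

namespace tfrt_stub = tensorflow::tfrt_stub;

tensorflow::Status AotCompileForServing() {
  // Default-constructed options: graph_execution_options stays nullptr, so
  // AotCompileSavedModel supplies the GPU defaults and {"serve", "gpu"} tags.
  tfrt_stub::AotOptions options;
  return tfrt_stub::AotCompileSavedModel(
      "/tmp/my_saved_model", std::move(options), "/tmp/my_saved_model_aot");
}

Passing AotOptions by value (rather than by const reference, as before this change) is what lets the function patch these defaults into the options without mutating any caller-owned state.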
Change needed to get a unit test (AutoShardingTest.Conditional2TupleArgTest) to pass after a regression in cl/549483733 PiperOrigin-RevId: 551639052 --- .../xla/hlo/experimental/auto_sharding/auto_sharding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc index dc4324495c325d..dcaf8d56073109 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -524,7 +524,6 @@ void AddReplicatedStrategy(const HloInstruction* ins, const Shape& shape, if (ins->opcode() == HloOpcode::kConditional) { resharding_costs.push_back(std::vector( strategy_map.at(operand)->leaf_vector.size(), 0)); - input_shardings.push_back(output_spec); } else { resharding_costs.push_back(ReshardingCostVector( strategy_map.at(operand).get(), ins->operand(k)->shape(), @@ -2582,6 +2581,7 @@ void SetHloShardingPostProcessing( } else { const ShardingStrategy& stra = GetShardingStrategy(inst, strategy_map, cost_graph, s_val); + if (stra.input_shardings.empty()) { continue; } From 7e723269d9e42cd6efb34e6b2bf043f42940e982 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Thu, 27 Jul 2023 14:30:47 -0700 Subject: [PATCH 273/410] #tf-data Turn down `"file_locality"` (v1) experiment. PiperOrigin-RevId: 551641998 --- tensorflow/core/data/dataset_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index c31e7ec1fc5965..535e24eb718ac5 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -976,7 +976,7 @@ REGISTER_DATASET_EXPERIMENT("stage_based_autotune_v2", RandomJobSamplePercentage<0>, IndependentHostTasks); REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<10>, +REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<0>, IndependentHostTasks); REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<10>, IndependentHostTasks); From bd2ad32de1b77c29078b0944ce7269b71091cc33 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Thu, 27 Jul 2023 14:44:23 -0700 Subject: [PATCH 274/410] Add TF::CollectiveReduceV2Op to AllowedTf2XlaFallback so it is not outside compiled. 
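The allow-list itself is a set of op TypeIDs, so the effect of this one-line addition can be checked with a direct membership query. A rough sketch, assuming it is compiled in the same namespace as legalization_op_config.cc with the TF dialect headers available (the wrapper function name here is invented for illustration):

// Returns true once TF::CollectiveReduceV2Op is on the tf2xla fallback
// allow-list, i.e. the op is legalized via the fallback instead of being
// outside compiled.
bool CollectiveReduceV2UsesFallback() {
  return IsOpTypeAllowedTf2XlaFallback(
      TypeID::get<mlir::TF::CollectiveReduceV2Op>());
}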
PiperOrigin-RevId: 551645759 --- .../compiler/mlir/tf2xla/transforms/legalization_op_config.cc | 1 + .../mlir/tf2xla/transforms/legalization_op_config_test.cc | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index fa138cc9bca4bb..79c3517b650780 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -176,6 +176,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index a194a5b330a20e..9f50e358e771f9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -130,8 +130,8 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 71); - EXPECT_EQ(tf2xla_fallback_count, 296); - EXPECT_EQ(non_categorized_count, 419); + EXPECT_EQ(tf2xla_fallback_count, 297); + EXPECT_EQ(non_categorized_count, 418); } // Just a counter test to see which ops have duplicate lowerings. This isn't a From a08c4bc51c7164a7bfd454d675e9478007bc19f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 15:06:23 -0700 Subject: [PATCH 275/410] When sharding propagation cannot infer shardings for the operand of a custom call HLO op, set the default input sharding as replicated. Change needed to make a unit test (AutoShardingTest.CustomCallWithUserSharding) pass. PiperOrigin-RevId: 551651797 --- .../xla/hlo/experimental/auto_sharding/auto_sharding.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc index dcaf8d56073109..5d175c0e9cd2f3 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -153,7 +153,8 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( } if (!cur_input_sharding.has_value() && ((ins->opcode() == HloOpcode::kGather && k == 0) || - (ins->opcode() == HloOpcode::kScatter && k != 0))) { + (ins->opcode() == HloOpcode::kScatter && k != 0) || + ins->opcode() == HloOpcode::kCustomCall)) { cur_input_sharding = HloSharding::Replicate(); } CHECK(cur_input_sharding.has_value()); From 4d7ea08f306960c3129e523b296baab0b5508628 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 27 Jul 2023 15:33:29 -0700 Subject: [PATCH 276/410] Add type annotations to the return values of convert_to_tensor_v1* and convert_to_tensor_v2*. 
PiperOrigin-RevId: 551659082 --- .../python/framework/tensor_conversion.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/framework/tensor_conversion.py b/tensorflow/python/framework/tensor_conversion.py index 0098d6c3d06873..25a89db9c80206 100644 --- a/tensorflow/python/framework/tensor_conversion.py +++ b/tensorflow/python/framework/tensor_conversion.py @@ -13,22 +13,15 @@ # limitations under the License. # ============================================================================== """Tensor conversion functions.""" -import typing - from tensorflow.python.framework import tensor_conversion_registry from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch from tensorflow.python.util import tf_export -if typing.TYPE_CHECKING: - # pylint: disable=g-bad-import-order - from tensorflow.python.framework import tensor as tensor_lib - # pylint: enable=g-bad-import-order - def convert_to_tensor_v1( value, dtype=None, name=None, preferred_dtype=None, dtype_hint=None -) -> "tensor_lib.Tensor": +): """Converts the given `value` to a `Tensor` (with the TF1 API).""" preferred_dtype = deprecation.deprecated_argument_lookup( "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype @@ -40,7 +33,7 @@ def convert_to_tensor_v1( @dispatch.add_dispatch_support def convert_to_tensor_v1_with_dispatch( value, dtype=None, name=None, preferred_dtype=None, dtype_hint=None -) -> "tensor_lib.Tensor": +): """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -103,7 +96,7 @@ def my_func(arg): @dispatch.add_dispatch_support def convert_to_tensor_v2_with_dispatch( value, dtype=None, dtype_hint=None, name=None -) -> "tensor_lib.Tensor": +): """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -169,9 +162,7 @@ def convert_to_tensor_v2_with_dispatch( ) -def convert_to_tensor_v2( - value, dtype=None, dtype_hint=None, name=None -) -> "tensor_lib.Tensor": +def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None): """Converts the given `value` to a `Tensor`.""" # preferred_dtype = preferred_dtype or dtype_hint return tensor_conversion_registry.convert( From 50505993b5325a679b4c38b86c40565e578bf9b4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 15:50:52 -0700 Subject: [PATCH 277/410] Add support for mesh shapes with >3 dimensions, but will <3 dimensions larger than 1. We handle this by removing all dimensions of size 1. Currently, we enable this only for meshes with >3 dimensions. In a following CL, we will, after more careful consideration, do this for all mesh shapes (regardless of their shape) which will simplify auto-sharding code in a lot of places. 
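Concretely, the compression described above drops every mesh dimension of size 1 together with the matching device_mesh_alpha / device_mesh_beta entries. The snippet below only illustrates that behavior under those assumptions; the helper name is invented and this is not the code in AutoShardingOption.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Drops mesh dimensions of size 1, keeping alpha/beta aligned with the shape.
void CompressMeshShape(std::vector<int64_t>& shape, std::vector<double>& alpha,
                       std::vector<double>& beta) {
  std::vector<int64_t> new_shape;
  std::vector<double> new_alpha, new_beta;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] > 1) {
      new_shape.push_back(shape[i]);
      new_alpha.push_back(alpha[i]);
      new_beta.push_back(beta[i]);
    }
  }
  shape = std::move(new_shape);
  alpha = std::move(new_alpha);
  beta = std::move(new_beta);
}

// Example: a mesh shape of {1, 4, 1, 8} becomes {4, 8}, so downstream
// auto-sharding code never sees more than three dimensions larger than 1.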
PiperOrigin-RevId: 551663524 --- .../auto_sharding/auto_sharding.h | 34 ++++++++++++++++--- .../auto_sharding/auto_sharding_util.h | 2 ++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h index 2af47b648945df..e258f3d359142f 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -259,11 +259,15 @@ struct AutoShardingOption { return absl::OutOfRangeError( "device_mesh_shape is empty and it needs to be specified."); } - if (device_mesh_shape.size() > 3) { - return absl::OutOfRangeError( - absl::StrCat("Not supported: the length of device_mesh_shape is " - "greater than 3, actual length: ", - device_mesh_shape.size())); + std::vector mesh_dims_greater_than_one_indices = + spmd::VectorGreaterThanOneElementIndices(device_mesh_shape); + + if (mesh_dims_greater_than_one_indices.size() > 3) { + return absl::OutOfRangeError(absl::StrCat( + "Not supported: only device_mesh_shapes with 3 or less " + "dimensions larger than 1 are supported. Instead we have ", + mesh_dims_greater_than_one_indices.size(), + " dimensions greater than 1.")); } // All values in device_mesh_shape must be greater than 0. if (absl::c_any_of(device_mesh_shape, @@ -313,6 +317,26 @@ struct AutoShardingOption { "device_mesh_beta, " "please leave them empty and default values will be used.")); } + + if (device_mesh_shape.size() > 3) { + std::vector compressed_device_mesh_shape; + std::vector compressed_device_mesh_alpha; + std::vector compressed_device_mesh_beta; + int non_zero_counter = 0; + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (non_zero_counter < mesh_dims_greater_than_one_indices.size() && + i == mesh_dims_greater_than_one_indices[non_zero_counter]) { + non_zero_counter++; + compressed_device_mesh_shape.push_back(device_mesh_shape[i]); + compressed_device_mesh_alpha.push_back(device_mesh_alpha[i]); + compressed_device_mesh_beta.push_back(device_mesh_beta[i]); + } + } + this->device_mesh_shape = compressed_device_mesh_shape; + this->device_mesh_alpha = compressed_device_mesh_alpha; + this->device_mesh_beta = compressed_device_mesh_beta; + } + int64_t total_devices = 1; for (auto i : device_mesh_shape) { total_devices *= i; diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 1f53678fc5d381..37801ebfbbb2e9 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -559,6 +559,8 @@ Array Transpose(const Array array, std::vector axes) { size_t VectorGreaterThanOneElementCount(absl::Span span, bool omit_last_dim = false); +// This functions returns the indices of all vector elements larger than 1, in +// order. std::vector VectorGreaterThanOneElementIndices( absl::Span span, bool omit_last_dim = false); From f433fe3ce95a2e7dbb1b579479060fa1f89e8ecf Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Thu, 27 Jul 2023 15:54:37 -0700 Subject: [PATCH 278/410] #tf-data Ramp up `"file_locality_v2"` experiment to 50%. 
PiperOrigin-RevId: 551664452 --- tensorflow/core/data/dataset_utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 535e24eb718ac5..0a86d042ec010a 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -978,7 +978,7 @@ REGISTER_DATASET_EXPERIMENT("data_transfer", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("file_locality", RandomJobSamplePercentage<0>, IndependentHostTasks); -REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<10>, +REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<50>, IndependentHostTasks); } // namespace } // namespace data From 7e0eb786857bda3e91f386af9ec36969adc83365 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 27 Jul 2023 16:05:36 -0700 Subject: [PATCH 279/410] [PJRT:C] Make PjRtCApiBuffer::has_dynamic_dimensions return false if unimplemented This way we can more gracefully handle plugins that don't have dynamic shapes implemented. PiperOrigin-RevId: 551667221 --- tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc | 12 ++++++++---- tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h | 2 ++ tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc | 12 ++++++++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc index 501531e0747acc..d50835a216f0f1 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -119,14 +119,18 @@ PJRT_TopologyDescriptionDeleter MakeTopologyDescriptionDeleter( }; } -absl::StatusCode PjrtErrorToStatusCode(const PJRT_Error* error, - const PJRT_Api* api) { +PJRT_Error_Code GetErrorCode(const PJRT_Error* error, const PJRT_Api* api) { PJRT_Error_GetCode_Args args; args.struct_size = PJRT_Error_GetCode_Args_STRUCT_SIZE; args.priv = nullptr; args.error = error; - api->PJRT_Error_GetCode(&args); - PJRT_Error_Code code = args.code; + pjrt::LogFatalIfPjrtError(api->PJRT_Error_GetCode(&args), api); + return args.code; +} + +absl::StatusCode PjrtErrorToStatusCode(const PJRT_Error* error, + const PJRT_Api* api) { + PJRT_Error_Code code = GetErrorCode(error, api); switch (code) { case PJRT_Error_Code_CANCELLED: case PJRT_Error_Code_UNKNOWN: diff --git a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h index ac9320d5edc8ee..9901824f6056fa 100644 --- a/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h @@ -110,6 +110,8 @@ void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api); absl::string_view GetPjrtErrorMessage(const PJRT_Error* error, const PJRT_Api* api); +PJRT_Error_Code GetErrorCode(const PJRT_Error* error, const PJRT_Api* api); + xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api); absl::StatusCode PjrtErrorToStatusCode(const PJRT_Error* error, diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc index ebad86d42f3811..6bbeb217b19515 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc @@ -1349,8 +1349,16 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; args.priv = nullptr; 
args.buffer = buffer_.get(); - pjrt::LogFatalIfPjrtError( - pjrt_c_api()->PJRT_Buffer_DynamicDimensionIndices(&args), pjrt_c_api()); + + const PJRT_Api* api = pjrt_c_api(); + std::unique_ptr error( + api->PJRT_Buffer_DynamicDimensionIndices(&args), + pjrt::MakeErrorDeleter(api)); + + if (error && + pjrt::GetErrorCode(error.get(), api) == PJRT_Error_Code_UNIMPLEMENTED) { + return false; + } return args.num_dynamic_dims > 0; } From 9daa492bf932267084dd80dc23f8557e161b1e54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 16:33:54 -0700 Subject: [PATCH 280/410] Remove all dimensions from the mesh shape with a length of 1. This simplifies some code in auto-sharding. PiperOrigin-RevId: 551674599 --- .../auto_sharding/auto_sharding.cc | 38 ++--- .../auto_sharding/auto_sharding.h | 44 +++-- .../auto_sharding_dot_handler.cc | 160 ++++++++---------- 3 files changed, 98 insertions(+), 144 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 5d175c0e9cd2f3..7024ae30066361 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -624,8 +624,6 @@ void EnumerateAll2DPartition(const HloInstruction* ins, const Shape& shape, const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, const CallGraph& call_graph) { - std::vector shardable_mesh_dims = - VectorGreaterThanOneElementIndices(device_mesh.dimensions()); auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); int64_t batch_dim = -1; if (iter != batch_dim_map.end()) { @@ -637,23 +635,19 @@ void EnumerateAll2DPartition(const HloInstruction* ins, const Shape& shape, if ((batch_dim != -1 && !(batch_dim == i || batch_dim == j)) || i == j) { continue; } - if (shape.dimensions(i) < device_mesh.dim(shardable_mesh_dims[0]) || - shape.dimensions(j) < device_mesh.dim(shardable_mesh_dims[1])) { + if (shape.dimensions(i) < device_mesh.dim(0) || + shape.dimensions(j) < device_mesh.dim(1)) { continue; } if (only_allow_divisible && - (!IsDivisible(shape.dimensions(i), - device_mesh.dim(shardable_mesh_dims[0])) || - !IsDivisible(shape.dimensions(j), - device_mesh.dim(shardable_mesh_dims[1])))) { + (!IsDivisible(shape.dimensions(i), device_mesh.dim(0)) || + !IsDivisible(shape.dimensions(j), device_mesh.dim(1)))) { continue; } std::string name = absl::StrFormat("S{%d,%d} @ {0,1}", i, j); - HloSharding output_spec = - Tile(shape, {i, j}, {shardable_mesh_dims[0], shardable_mesh_dims[1]}, - device_mesh); + HloSharding output_spec = Tile(shape, {i, j}, {0, 1}, device_mesh); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(shape) / output_spec.NumTiles(); std::vector input_shardings; @@ -757,8 +751,6 @@ void Enumerate2DPartitionReshape(const HloInstruction* ins, const InstructionBatchDimMap& batch_dim_map, std::unique_ptr& strategies, bool only_allow_divisible) { - std::vector shardable_mesh_dims = - VectorGreaterThanOneElementIndices(device_mesh.dimensions()); auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); int64_t batch_dim = -1; if (iter != batch_dim_map.end()) { @@ -773,23 +765,17 @@ void Enumerate2DPartitionReshape(const HloInstruction* ins, if ((batch_dim != -1 && !(batch_dim == i || batch_dim == j)) || i == j) { continue; } - if (ins->shape().dimensions(i) < - device_mesh.dim(shardable_mesh_dims[0]) || - ins->shape().dimensions(j) < - 
device_mesh.dim(shardable_mesh_dims[1])) { + if (ins->shape().dimensions(i) < device_mesh.dim(0) || + ins->shape().dimensions(j) < device_mesh.dim(1)) { continue; } if (only_allow_divisible && - (!IsDivisible(ins->shape().dimensions(i), - device_mesh.dim(shardable_mesh_dims[0])) || - !IsDivisible(ins->shape().dimensions(j), - device_mesh.dim(shardable_mesh_dims[1])))) { + (!IsDivisible(ins->shape().dimensions(i), device_mesh.dim(0)) || + !IsDivisible(ins->shape().dimensions(j), device_mesh.dim(1)))) { continue; } - HloSharding output_spec = - Tile(ins->shape(), {i, j}, - {shardable_mesh_dims[0], shardable_mesh_dims[1]}, device_mesh); + HloSharding output_spec = Tile(ins->shape(), {i, j}, {0, 1}, device_mesh); std::optional input_spec = hlo_sharding_util::ReshapeSharding(ins->shape(), operand->shape(), output_spec); @@ -797,9 +783,7 @@ void Enumerate2DPartitionReshape(const HloInstruction* ins, continue; } - std::string name = - absl::StrFormat("S%d%d @ {%d,%d}", i, j, shardable_mesh_dims[0], - shardable_mesh_dims[1]); + std::string name = absl::StrFormat("S%d%d @ {%d,%d}", i, j, 0, 1); double compute_cost = 0, communication_cost = 0; double memory_cost = GetBytes(ins->shape()) / output_spec.NumTiles(); diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h index e258f3d359142f..e9cbd7f27c1489 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -299,13 +299,6 @@ struct AutoShardingOption { << absl::StrJoin(device_mesh_beta, ","); } - // If device_mesh_shape has only one value, append 1 to it - if (device_mesh_shape.size() == 1) { - device_mesh_shape.push_back(1); - device_mesh_alpha.push_back(1.0); - device_mesh_beta.push_back(1.0); - } - if (device_mesh_shape.size() != device_mesh_alpha.size() || device_mesh_shape.size() != device_mesh_beta.size()) { return absl::OutOfRangeError(absl::StrCat( @@ -318,23 +311,28 @@ struct AutoShardingOption { "please leave them empty and default values will be used.")); } - if (device_mesh_shape.size() > 3) { - std::vector compressed_device_mesh_shape; - std::vector compressed_device_mesh_alpha; - std::vector compressed_device_mesh_beta; - int non_zero_counter = 0; - for (size_t i = 0; i < device_mesh_shape.size(); ++i) { - if (non_zero_counter < mesh_dims_greater_than_one_indices.size() && - i == mesh_dims_greater_than_one_indices[non_zero_counter]) { - non_zero_counter++; - compressed_device_mesh_shape.push_back(device_mesh_shape[i]); - compressed_device_mesh_alpha.push_back(device_mesh_alpha[i]); - compressed_device_mesh_beta.push_back(device_mesh_beta[i]); - } + std::vector compressed_device_mesh_shape; + std::vector compressed_device_mesh_alpha; + std::vector compressed_device_mesh_beta; + int non_zero_counter = 0; + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (non_zero_counter < mesh_dims_greater_than_one_indices.size() && + i == mesh_dims_greater_than_one_indices[non_zero_counter]) { + non_zero_counter++; + compressed_device_mesh_shape.push_back(device_mesh_shape[i]); + compressed_device_mesh_alpha.push_back(device_mesh_alpha[i]); + compressed_device_mesh_beta.push_back(device_mesh_beta[i]); } - this->device_mesh_shape = compressed_device_mesh_shape; - this->device_mesh_alpha = compressed_device_mesh_alpha; - this->device_mesh_beta = compressed_device_mesh_beta; + } + this->device_mesh_shape = compressed_device_mesh_shape; 
+ this->device_mesh_alpha = compressed_device_mesh_alpha; + this->device_mesh_beta = compressed_device_mesh_beta; + + // If device_mesh_shape has only one value, append 1 to it + if (device_mesh_shape.size() == 1) { + device_mesh_shape.push_back(1); + device_mesh_alpha.push_back(1.0); + device_mesh_beta.push_back(1.0); } int64_t total_devices = 1; diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index e50134cd4bf6ad..c087a3ee822323 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -618,74 +618,62 @@ class DotHandler { } Status RegisterStrategies() { - std::vector shardable_mesh_dims = - VectorGreaterThanOneElementIndices(device_mesh_.dimensions()); - // For 1D sharding - if (shardable_mesh_dims.size() == 1) { - shardable_mesh_dims.push_back((shardable_mesh_dims.at(0) + 1) % - device_mesh_.num_dimensions()); - } + auto mesh_shape = device_mesh_.dimensions(); // SS = SR x RS // Split lhs space dim and rhs space dim. - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitLhsSpaceRhsSpace(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitLhsSpaceRhsSpace(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitLhsSpaceRhsSpace(i, j); + SplitLhsSpaceRhsSpace(j, i); } } // SSR = SSR x RR // Split lhs space dims only if it has more than 1 space dims. if (lhs_space_dims_.size() > 1) { - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitLhsSpaceOnly(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitLhsSpaceOnly(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitLhsSpaceOnly(i, j); + SplitLhsSpaceOnly(j, i); } } } // RSS = RR x RSS // Split rhs space dims only if it has more than 1 space dims. if (rhs_space_dims_.size() > 1) { - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitRhsSpaceOnly(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitRhsSpaceOnly(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitRhsSpaceOnly(i, j); + SplitRhsSpaceOnly(j, i); } } } // SR = SS x SR // Split lhs space dim and both contracting dims. - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitLhsSpaceBothContract(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitLhsSpaceBothContract(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitLhsSpaceBothContract(i, j); + SplitLhsSpaceBothContract(j, i); } } // RS = RS x SS // Split rhs space dim and both contracting dims. 
- for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitRhsSpaceBothContract(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitRhsSpaceBothContract(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitRhsSpaceBothContract(i, j); + SplitRhsSpaceBothContract(j, i); } } // RR = SS x SS // Split two contracting dims on lhs and rhs. - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitBothContractTwoDims(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitBothContractTwoDims(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitBothContractTwoDims(i, j); + SplitBothContractTwoDims(j, i); } } @@ -693,12 +681,10 @@ class DotHandler { // This is a special case where we allow spliting only one dim in the // multi-dimensional mesh case. This allows some recomputation // (e.g., the dense layer in the LM_head of BERT). - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - RecomputeSplitBothContract(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - RecomputeSplitBothContract(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + RecomputeSplitBothContract(i, j); + RecomputeSplitBothContract(j, i); } } @@ -721,30 +707,28 @@ class DotHandler { // SbSi = SbSi x SbR // Split batch dim and lhs space dim - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitBatchDimLhsSpace(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitBatchDimLhsSpace(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitBatchDimLhsSpace(i, j); + SplitBatchDimLhsSpace(j, i); } } // SbSj = SbR x SbSj // Split batch dim and rhs space dim - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitBatchDimRhsSpace(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitBatchDimRhsSpace(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitBatchDimRhsSpace(i, j); + SplitBatchDimRhsSpace(j, i); } } // SbSj = SbR x SbSj // Split batch dim and contracting dim - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitBatchDimBothContract(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitBatchDimBothContract(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitBatchDimBothContract(i, j); + SplitBatchDimBothContract(j, i); } } @@ -759,10 +743,10 @@ class DotHandler { // Sb = Sb x Sb // Split batch dims. 
- for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitTwoBatchDims(shardable_mesh_dims[i], shardable_mesh_dims[j]); - SplitTwoBatchDims(shardable_mesh_dims[j], shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitTwoBatchDims(i, j); + SplitTwoBatchDims(j, i); } } @@ -968,14 +952,8 @@ class ConvHandler { } Status RegisterStrategies() { - std::vector shardable_mesh_dims = - VectorGreaterThanOneElementIndices(device_mesh_.dimensions()); + auto mesh_shape = device_mesh_.dimensions(); // For 1D sharding - if (shardable_mesh_dims.size() == 1) { - shardable_mesh_dims.push_back((shardable_mesh_dims.at(0) + 1) % - device_mesh_.num_dimensions()); - } - if ((ins_->feature_group_count() == lhs_->shape().dimensions(lhs_in_channel_dim_) && ins_->feature_group_count() == @@ -983,10 +961,10 @@ class ConvHandler { // for depthwise conv // SS = SS x S // Split batch dim and channel dim - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitDepthwise(shardable_mesh_dims[i], shardable_mesh_dims[j], true); - SplitDepthwise(shardable_mesh_dims[j], shardable_mesh_dims[i], true); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitDepthwise(i, j, true); + SplitDepthwise(j, i, true); } } } else if ((ins_->batch_group_count() == @@ -996,44 +974,38 @@ class ConvHandler { // for depthwise conv filter_backward // SS = SS x S // Split batch dim and channel dim - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitDepthwise(shardable_mesh_dims[i], shardable_mesh_dims[j], false); - SplitDepthwise(shardable_mesh_dims[j], shardable_mesh_dims[i], false); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitDepthwise(i, j, false); + SplitDepthwise(j, i, false); } } } // SS = SR x RS // Split lhs batch dim and rhs out_channel dim. - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitLhsBatchRhsOutchannel(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitLhsBatchRhsOutchannel(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitLhsBatchRhsOutchannel(i, j); + SplitLhsBatchRhsOutchannel(j, i); } } // SR = SS x SR // Split lhs batch dim and both in_channel dims. - for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitLhsBatchBothInchannel(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitLhsBatchBothInchannel(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitLhsBatchBothInchannel(i, j); + SplitLhsBatchBothInchannel(j, i); } } // RS = RS x SS // Split rhs out_channel dim and both in_channel dims. 
- for (int64_t i = 0; i < shardable_mesh_dims.size(); ++i) { - for (int64_t j = (i + 1); j < shardable_mesh_dims.size(); ++j) { - SplitRhsOutchannelBothInchannel(shardable_mesh_dims[i], - shardable_mesh_dims[j]); - SplitRhsOutchannelBothInchannel(shardable_mesh_dims[j], - shardable_mesh_dims[i]); + for (int64_t i = 0; i < mesh_shape.size(); ++i) { + for (int64_t j = (i + 1); j < mesh_shape.size(); ++j) { + SplitRhsOutchannelBothInchannel(i, j); + SplitRhsOutchannelBothInchannel(j, i); } } From 222a63c5684971c9038ee0140750fde7c5da7d1e Mon Sep 17 00:00:00 2001 From: David Silverstone Date: Thu, 27 Jul 2023 16:37:03 -0700 Subject: [PATCH 281/410] [Refactoring] Replace `absl::optional`, `absl::variant`, and related with `std::` PiperOrigin-RevId: 551675329 --- .../compiler/xla/stream_executor/tpu/BUILD | 10 +---- .../stream_executor/tpu/noncopyable_buffer.h | 2 +- .../tpu/tpu_executable_interface.h | 1 - .../tpu/tpu_initializer_helper.cc | 1 - .../stream_executor/tpu/tpu_node_context.cc | 1 + .../stream_executor/tpu/tpu_node_context.h | 6 +-- .../tpu/tpu_on_demand_compiler.cc | 6 --- tensorflow/core/tpu/BUILD | 5 +-- tensorflow/core/tpu/graph_rewrite/BUILD | 2 - .../distributed_tpu_rewrite_pass.cc | 29 +++++------- .../encapsulate_tpu_computations_pass.cc | 22 ++++----- .../host_training_loop_optimization_util.h | 8 ++-- .../update_tpu_embedding_ops_passes.cc | 6 +-- tensorflow/core/tpu/kernels/BUILD | 45 +++++-------------- .../kernels/tpu_compilation_cache_external.cc | 12 ++--- .../kernels/tpu_compilation_cache_external.h | 20 +-------- .../tpu/kernels/tpu_compilation_cache_key.h | 1 - .../tpu_compilation_cache_rpc_support.cc | 3 +- .../core/tpu/kernels/tpu_compile_op_common.cc | 1 - .../core/tpu/kernels/tpu_compile_op_common.h | 4 +- .../core/tpu/kernels/tpu_compile_op_impl.cc | 5 ++- .../core/tpu/kernels/tpu_compile_op_impl.h | 6 +-- .../tpu/kernels/tpu_compile_op_support.cc | 4 +- .../core/tpu/kernels/tpu_compile_op_support.h | 6 +-- tensorflow/core/tpu/kernels/tpu_execute_op.cc | 10 +---- .../core/tpu/kernels/tpu_fingerprint_lookup.h | 3 +- .../core/tpu/kernels/tpu_program_group.h | 1 - tensorflow/core/tpu/tpu_compile.cc | 24 ++++++---- tensorflow/core/tpu/tpu_compile.h | 5 +++ tensorflow/core/tpu/tpu_global_init.cc | 8 +--- 30 files changed, 88 insertions(+), 169 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/tpu/BUILD b/tensorflow/compiler/xla/stream_executor/tpu/BUILD index e47afbc5fe747a..e6b8c4c0083bee 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/tpu/BUILD @@ -110,7 +110,6 @@ cc_library( "//tensorflow/tsl/platform:platform_port", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -298,16 +297,15 @@ cc_library( deps = [ ":status_helper", ":tpu_api", + ":tpu_ops_c_api_hdrs", ":tpu_platform_interface", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "//tensorflow/tsl/platform:macros", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:statusor", - "@com_google_absl//absl/memory", ], ) @@ -465,18 +463,13 @@ cc_library( ":tpu_executor_c_api_hdrs", ":tpu_executor_hdrs", ":tpu_platform_id", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - 
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/hlo/ir:hlo_module_group", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo_cost_analysis", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/types:span", ], alwayslink = True, ) @@ -512,7 +505,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h b/tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h index cde608dc2ef49d..dedc1c0b3d8292 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/noncopyable_buffer.h @@ -19,11 +19,11 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/casts.h" #include "absl/functional/function_ref.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/mem.h" diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_executable_interface.h b/tensorflow/compiler/xla/stream_executor/tpu/tpu_executable_interface.h index 2a7d678b6060f3..425789ec41beae 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_executable_interface.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_executable_interface.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc index 645dfc84dcc0c8..e33248ae9ca902 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc @@ -25,7 +25,6 @@ limitations under the License. #include #include -#include #include #include #include diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.cc index 4b08a171e8967e..7124ecb8a1728a 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" namespace tensorflow { diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h b/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h index 20c66fb2c33612..759814dcf2d801 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h @@ -16,14 +16,12 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_ #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_ -#include +#include -#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" +#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" #include "tensorflow/tsl/platform/macros.h" #include "tensorflow/tsl/platform/status.h" diff --git a/tensorflow/compiler/xla/stream_executor/tpu/tpu_on_demand_compiler.cc b/tensorflow/compiler/xla/stream_executor/tpu/tpu_on_demand_compiler.cc index 02b56476643c7e..0e51d2b953b924 100644 --- a/tensorflow/compiler/xla/stream_executor/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/compiler/xla/stream_executor/tpu/tpu_on_demand_compiler.cc @@ -15,15 +15,11 @@ limitations under the License. #include #include "absl/cleanup/cleanup.h" -#include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module_group.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" @@ -31,10 +27,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executable.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_id.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 6344b65ce6c16b..b4bb40dc372755 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -223,11 +223,10 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/core:core_cpu_base", "//tensorflow/core/framework:attr_value_proto_cc", - "//tensorflow/core/framework:versions_proto_cc", - "//tensorflow/core/platform:statusor", + "//tensorflow/core/platform:status", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", - "//tensorflow/core/tpu/kernels:tpu_util_hdrs", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 5873bf83aca7bd..fac15fc3bf0197 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -241,13 +241,11 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", - "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index 9424a9c3b9ee48..70e88e8b59e8fd 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -36,18 +36,15 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_interface.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/device_propagation.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/lower_function_call_op.h" #include "tensorflow/core/common_runtime/lower_functional_ops.h" @@ -58,23 +55,18 @@ limitations under the License. 
#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/error_payloads.h" -#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/protobuf/core_platform_payloads.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" -#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/tpu/graph_rewrite/cond_builder.h" @@ -82,7 +74,6 @@ limitations under the License. #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h" #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h" #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h" -#include "tensorflow/core/tpu/tpu_compile_interface.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/tpu/tpu_fingerprint_utils.h" #include "tensorflow/core/util/device_name_utils.h" @@ -1425,7 +1416,7 @@ struct NodeAndSharding { Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding, const int num_cores_per_replica, int64_t* inferred_core_id, - absl::optional* result) { + std::optional* result) { if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) { int64_t core_annotation = node_and_sharding.sharding.tile_assignment_devices(0); @@ -1764,7 +1755,7 @@ static Status BuildGeneralDeviceAssignment( std::unique_ptr* xla_device_assignment) { // Assign TensorFlow devices to each computation's replicas according to // device_assignment and 'topology'. - *xla_device_assignment = absl::make_unique( + *xla_device_assignment = std::make_unique( num_replicas, num_cores_per_replica); for (int replica = 0; replica < num_replicas; ++replica) { for (int computation = 0; computation < num_cores_per_replica; @@ -2128,7 +2119,7 @@ bool UseSpmdForXlaPartitioning(const Node* replicate_node) { } std::string FormatNodeAndShardingMsg( - const absl::optional& node_and_sharding) { + const std::optional& node_and_sharding) { DCHECK(node_and_sharding.has_value()); xla::OpSharding sharding_no_metadata = node_and_sharding->sharding; @@ -2236,7 +2227,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( i + (is_per_replica_arg ? 
0 : index_offset), &input_node)); if (_IsTPUPartitionedInput(input_node)) { TF_ASSIGN_OR_RETURN( - absl::optional parsed_sharding, + std::optional parsed_sharding, GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true)); if (!parsed_sharding.has_value()) return errors::InvalidArgument("Missing _XlaSharding attr from: ", @@ -2254,7 +2245,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( replicate_node->input_node(i + index_offset, &input_node)); if (input_node->type_string() == kVarHandleOp) { TF_ASSIGN_OR_RETURN( - absl::optional parsed_sharding, + std::optional parsed_sharding, GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true)); if (parsed_sharding.has_value()) { node_and_sharding = NodeAndSharding(input_node, *parsed_sharding); @@ -2366,11 +2357,11 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge)); TF_ASSIGN_OR_RETURN( - absl::optional edge_sharding, + std::optional edge_sharding, ParseShardingFromEdgeSource(*edge, num_cores_per_replica, /*add_metadata=*/true)); - absl::optional node_and_sharding; + std::optional node_and_sharding; if (edge_sharding.has_value()) { node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding)); } @@ -2378,7 +2369,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( if (partitioned_output_nodes.contains(i)) { Node* output_node = partitioned_output_nodes[i]; TF_ASSIGN_OR_RETURN( - absl::optional parsed_sharding, + std::optional parsed_sharding, GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true)); if (parsed_sharding.has_value()) { node_and_sharding = NodeAndSharding(output_node, *parsed_sharding); @@ -2387,7 +2378,7 @@ Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores( << parsed_sharding->DebugString(); } } - absl::optional assigned_core; + std::optional assigned_core; if (node_and_sharding.has_value()) { if (enable_automatic_model_parallelism_) { return tensorflow::errors::InvalidArgument( @@ -4497,7 +4488,7 @@ DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes( TF_RETURN_IF_ERROR( GetNodeAttr(replicate_node.attrs(), "computation", function)); - *computation = absl::make_unique(graph->op_registry()); + *computation = std::make_unique(graph->op_registry()); TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp( **function, flr, computation->get(), arg_types, retval_types)); diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index f3915c680cc908..2802b254b52a1f 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -15,12 +15,15 @@ limitations under the License. #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" +#include #include +#include +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_util.h" @@ -1604,7 +1607,7 @@ Status RenameClustersWithDuplicatedNames(Graph* g) { // Start with outputs of TPUReplicateMetadata and follow output edges. 
std::queue queue; queue.push(iter.second.at(i)); - std::unordered_set visited; + absl::flat_hash_set visited; while (!queue.empty()) { Node* n = queue.front(); queue.pop(); @@ -1840,7 +1843,7 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, Status LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef( const FunctionBody& fbody, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, int* lifted_arg_count, - absl::optional new_func_name, bool* rewritten) { + std::optional new_func_name, bool* rewritten) { *rewritten = false; TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs( fbody.graph, flr, fld, lifted_arg_count, rewritten)); @@ -2310,7 +2313,7 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, bool func_rewritten = false; TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef( *body_fbody, flr, fld, lifted_arg_count, - /*new_func_name=*/absl::nullopt, &func_rewritten)); + /*new_func_name=*/std::nullopt, &func_rewritten)); *rewritten = *rewritten || func_rewritten; while_nodes.push_back(n); @@ -2321,7 +2324,7 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, bool func_rewritten = false; TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef( *then_branch_fbody, flr, fld, lifted_arg_count, - /*new_func_name=*/absl::nullopt, &func_rewritten)); + /*new_func_name=*/std::nullopt, &func_rewritten)); *rewritten |= func_rewritten; TF_ASSIGN_OR_RETURN( @@ -2330,7 +2333,7 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, func_rewritten = false; TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef( *else_branch_fbody, flr, fld, lifted_arg_count, - /*new_func_name=*/absl::nullopt, &func_rewritten)); + /*new_func_name=*/std::nullopt, &func_rewritten)); *rewritten |= func_rewritten; if_nodes.push_back(n); @@ -2470,7 +2473,7 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, TF_RETURN_IF_ERROR( PerformStaticShapeInferenceBeforeEncapsulation(graph->get())); - auto output = absl::make_unique((*graph)->op_registry()); + auto output = std::make_unique((*graph)->op_registry()); TF_RETURN_WITH_CONTEXT_IF_ERROR( EncapsulateSubgraphsInFunctions( kTPUReplicateAttr, **graph, RewriteSubgraph, @@ -2574,9 +2577,8 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, GetNodeAttr(in_edges[pos]->src()->attrs(), "N", &input_num_replicas)); bool is_mirrored_variable; - CHECK(GetNodeAttr(in_edges[pos]->src()->attrs(), "is_mirrored_variable", - &is_mirrored_variable) - .ok()); + CHECK_OK(GetNodeAttr(in_edges[pos]->src()->attrs(), + "is_mirrored_variable", &is_mirrored_variable)); if (is_mirrored_variable) { mirrored_variable_indices.push_back(pos); } diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h index 822dc9edd510ba..3e2bc9212e3120 100644 --- a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h +++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_ #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_ +#include #include #include #include -#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/graph/graph.h" @@ -31,7 +31,7 @@ namespace tpu { struct LoopArgInfo { std::string enter_node_name; // Exit nodes are optional for loop invariant while loop args. - absl::optional exit_node_name; + std::optional exit_node_name; }; struct HostTrainingLoopInfo { @@ -39,8 +39,8 @@ struct HostTrainingLoopInfo { // host training loop is included. If host training loop is not // inside a function call, then `function_name` and `function_attrs` // are nullopt. - absl::optional encapsulating_function_name; - absl::optional encapsulating_function_attrs; + std::optional encapsulating_function_name; + std::optional encapsulating_function_attrs; // TPU Compile node as within a host training loop. std::string compile_node_name; diff --git a/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.cc b/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.cc index e2991b819f6827..7e4d336d1ca079 100644 --- a/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.cc +++ b/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.cc @@ -160,7 +160,7 @@ Status ComputeEnqueueTrainingStatus( bool send_exists = (found_grad_send_op.find(node.first) != found_grad_send_op.end()); VLOG(1) << "Found call " << node.first - << (send_exists ? " with " : " without ") << " send op(s)."; + << (send_exists ? " with " : " without ") << " send op(s)."; // If we have found a send gradient op for that is in the same cluster as // the enqueue op, then this is a training call so set the output to true // for this @@ -195,8 +195,8 @@ Status UpdateTPUEmbeddingModePass::GetEnqueueOpsFromGraph( // Update the graph for a specific enqueue op. Status UpdateTPUEmbeddingModePass::UpdateGraphEnqueueOp(bool training, - Graph* graph, - Node* enqueue) { + Graph* graph, + Node* enqueue) { // When using the layer, the mode override input is a SelectV2 op (unless this // pass has already run), which takes a training and eval op as input. We will // simply short circut the SelectV2 and take input from the correct op. 
diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index fd0b342028aae7..101f15da8d59da 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -99,9 +99,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", ], alwayslink = 1, ) @@ -275,10 +273,7 @@ cc_library( hdrs = [ "tpu_compilation_cache_key.h", ], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - ], + deps = ["@com_google_absl//absl/strings"], ) cc_library( @@ -308,9 +303,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", ], ) @@ -431,7 +424,6 @@ cc_library( "//tensorflow/compiler/xrt:xrt_proto_cc", "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", - "@com_google_absl//absl/types:optional", ], ) @@ -476,32 +468,18 @@ cc_library( ], deps = [ ":compiled_subgraph", - ":tpu_compilation_cache_common_proto_cc", - ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_key", ":tpu_compilation_metrics", # buildcleaner: keep - ":tpu_compilation_metrics_hdrs", - ":tpu_compile_op_support", - ":tpu_mesh_state_interface", - ":tpu_op_consts", ":tpu_program_group", ":tpu_util", - ":trace_util_hdrs", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_ops_c_api_hdrs", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/base:core_headers", ], ) @@ -748,15 +726,10 @@ cc_library( ":tpu_program_group", ":tpu_program_group_interface", ":tpu_util", - "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_c_api_hdrs", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_hdrs", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_ops_c_api_hdrs", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "@com_google_absl//absl/types:variant", + "//tensorflow/core/platform:status", ], alwayslink = 1, ) @@ -794,12 +767,11 @@ cc_library( hdrs = ["tpu_execute_op.h"], deps = [ ":tpu_compilation_cache_entry", - ":tpu_compilation_cache_external", ":tpu_compilation_cache_interface", - ":tpu_compilation_cache_local_lookup", ":tpu_compilation_cache_lookup", ":tpu_executable_info_proto_cc", ":tpu_op_consts", + ":tpu_program_group", "//tensorflow/compiler/jit:variable_info", "//tensorflow/compiler/jit:variable_info_util", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", @@ -819,13 +791,11 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - 
"//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/tpu:tpu_execute", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], alwayslink = True, @@ -1146,7 +1116,6 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core/platform:stringpiece", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", @@ -1234,6 +1203,12 @@ cc_library( tags = ["avoid_dep"], textual_hdrs = ["tpu_compile_op_impl.h"], visibility = ["//visibility:public"], + deps = [ + ":tpu_compilation_cache_key", + ":tpu_compile_op_common", + ":tpu_program_group_interface", + "//tensorflow/compiler/xla/stream_executor/tpu:tpu_ops_c_api_hdrs", + ], ) tf_cc_test( diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc index b0931ab7b09952..c3e59671dd0745 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -14,24 +14,20 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" +#include #include #include #include #include +#include -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" -#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group.h" #include "tensorflow/core/tpu/kernels/tpu_util.h" -#include "tensorflow/core/tpu/kernels/trace_util.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index aac50489c9e095..17c4289b0a83ca 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -16,29 +16,11 @@ limitations under the License. 
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_ #include -#include -#include -#include -#include "absl/container/node_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/refcount.h" -#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "absl/base/thread_annotations.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" -#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_op_consts.h" -#include "tensorflow/core/tpu/kernels/tpu_program_group.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h index 81997e69a32789..c52ef1135c3fc8 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "absl/types/optional.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index 05c5e9488937fd..40877afd95b0f8 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -54,7 +54,7 @@ Status DeserializeRpcResponseToCacheEntry( }); // When we lookup from remote cache, we fetch a TPU program for a specific // core, hence we allocate TPU program group for a single program. - auto tpu_program_group = absl::make_unique(); + auto tpu_program_group = std::make_unique(); // TODO(b/166575150): can be optimized by sending the buffer over the gRPC // without an extra deserializing. @@ -114,7 +114,6 @@ xla::StatusOr> SerializeCacheEntryToBufferSlices( } header.set_is_empty(false); - bool may_modify_variables = tpu_program_group->may_modify_variables(cache_entry.core_index()); header.set_may_modify_variables(may_modify_variables); diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index ac5d9b37643ae9..733f78fcb82c6c 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h index f9b7136b294047..99f1bf29ad34bd 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h @@ -19,10 +19,10 @@ limitations under the License. #include #include #include +#include #include #include "absl/types/span.h" -#include "absl/types/variant.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/statusor.h" @@ -98,7 +98,7 @@ class TpuCompileOpKernelCommon { // Lowers Mlir or TF Function computation into HLO IR and using XLA compiler // compiles into TPU programs ready for execution. virtual Status Compile( - const absl::variant& computation, + const std::variant& computation, const XLA_TpuMeshState* mesh_state, const std::vector& arg_shapes, const TpuCompilationCacheKey* key, diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc index d0358b70dcea9d..f85dbf4d9a002d 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc @@ -16,10 +16,11 @@ limitations under the License. #include #include +#include #include -#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/kernels/tpu_program_group.h" @@ -30,7 +31,7 @@ namespace tpu { using tsl::StatusOr; Status TpuCompileOpKernelImpl::Compile( - const absl::variant& computation, + const std::variant& computation, const XLA_TpuMeshState* mesh_state, const std::vector& arg_shapes, const TpuCompilationCacheKey* key, diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h index a70c0dc33d23c6..067d13ae83bfe6 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h @@ -16,12 +16,10 @@ limitations under the License. 
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_ #include +#include #include -#include "absl/types/variant.h" -#include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" @@ -52,7 +50,7 @@ class TpuCompileOpKernelImpl : public TpuCompileOpKernelCommon { unload_cache_on_session_close, /*persistent_cache=*/nullptr) {} Status Compile( - const absl::variant& computation, + const std::variant& computation, const XLA_TpuMeshState* mesh_state, const std::vector& arg_shapes, const TpuCompilationCacheKey* key, diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc index a88d546bdfb78f..d0832dfaf974c1 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -317,7 +317,7 @@ Status ComputeOutputShapesForEachCore( Status CreateHloModules( const TPUCompileMetadataProto& metadata, const tensorflow::XlaCompiler::CompilationResult& compilation_result, - const absl::optional& device_assignment, + const std::optional& device_assignment, std::vector>* hlo_modules) { TF_RET_CHECK( compilation_result.computation->proto().has_host_program_shape()); @@ -345,7 +345,7 @@ Status CreateHloModules( } StatusOr CreateTpuCompilationRequest( - const absl::variant& computation, + const std::variant& computation, const TPUCompileMetadataProto& metadata, const std::vector& arg_shapes) { VLOG(1) << "CreateTpuCompilationRequest."; diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index ce7c39ae800c0a..d3e31f266bc91f 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -23,9 +23,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "absl/types/optional.h" #include "absl/types/span.h" -#include "absl/types/variant.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_module_group.h" @@ -138,11 +136,11 @@ tsl::Status ComputeOutputShapesForEachCore( tsl::Status CreateHloModules( const TPUCompileMetadataProto& metadata, const XlaCompiler::CompilationResult& compilation_result, - const absl::optional& device_assignment, + const std::optional& device_assignment, std::vector>* hlo_modules); tsl::StatusOr CreateTpuCompilationRequest( - const absl::variant& computation, + const std::variant& computation, const TPUCompileMetadataProto& metadata, const std::vector& arg_shapes); diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc index 50a2593fa2547f..44374bd277719b 100644 --- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/jit/variable_info.h" #include "tensorflow/compiler/jit/variable_info_util.h" @@ -36,29 +35,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_node_context.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" #include "tensorflow/core/tpu/kernels/tpu_op_consts.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group.h" #include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/tpu/tpu_execute.h" -#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h b/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h index 8519fa760c0364..fe98817c65dd81 100644 --- a/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h @@ -64,8 +64,7 @@ class TpuFingerprintLookup : public ResourceBase { bool RegisterIntermediateAndValuePair(uint64 intermediate, std::string value); // Look up fingerprint with key. - // Return absl::optional<::tensorflow::StringPiece>{} if - // not found. + // Return std::nullopt if not found. std::optional<::tensorflow::StringPiece> Lookup(uint64 key); size_t num_valid() { diff --git a/tensorflow/core/tpu/kernels/tpu_program_group.h b/tensorflow/core/tpu/kernels/tpu_program_group.h index a3a8d7fb6931a3..7d7771451f5331 100644 --- a/tensorflow/core/tpu/kernels/tpu_program_group.h +++ b/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" diff --git a/tensorflow/core/tpu/tpu_compile.cc b/tensorflow/core/tpu/tpu_compile.cc index 1ec2e073e81ae3..46f2ca38d2302b 100644 --- a/tensorflow/core/tpu/tpu_compile.cc +++ b/tensorflow/core/tpu/tpu_compile.cc @@ -12,8 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ + #include "tensorflow/core/tpu/tpu_compile.h" +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/layout_util.h" @@ -23,11 +33,9 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { @@ -284,8 +292,8 @@ Status BuildComputationArgumentDescriptions( arg.kind = XlaCompiler::Argument::kConstant; guaranteed_constants_size = guaranteed_constants.index() == 0 - ? absl::get<0>(guaranteed_constants).size() - : absl::get<1>(guaranteed_constants)->size(); + ? std::get<0>(guaranteed_constants).size() + : std::get<1>(guaranteed_constants)->size(); TF_RET_CHECK(constant_count < guaranteed_constants_size) << "More constant args in TPUCompileMetadataProto than constant " "tensors."; @@ -294,13 +302,13 @@ Status BuildComputationArgumentDescriptions( // const>`. Tensor tensor; CHECK(tensor.FromProto( - *absl::get<0>(guaranteed_constants)[constant_count++])) + *std::get<0>(guaranteed_constants)[constant_count++])) << "Failed to deserialize invalid `TensorProto` into `Tensor`."; arg.constant_value = tensor; } else { // `guaranteed_constants` is of type `const OpInputList* const`. arg.constant_value = - (*absl::get<1>(guaranteed_constants))[constant_count++]; + (*std::get<1>(guaranteed_constants))[constant_count++]; } break; case tpu::TPUCompileMetadataProto::Arg::INVALID: @@ -464,7 +472,7 @@ Status GetShardingInfo( TF_ASSIGN_OR_RETURN(auto arg_sharding, xla::HloSharding::FromProto(proto_arg.sharding())); auto layout_preference = shape_determination_fns.layout_preference_fn( - arg_shapes[i], proto_arg.dtype(), absl::nullopt); + arg_shapes[i], proto_arg.dtype(), std::nullopt); TF_ASSIGN_OR_RETURN(auto xla_arg_shape, shape_determination_fns.shape_representation_fn( arg_shapes[i], proto_arg.dtype(), diff --git a/tensorflow/core/tpu/tpu_compile.h b/tensorflow/core/tpu/tpu_compile.h index 55ee52a26cddd8..96eda37e0bb675 100644 --- a/tensorflow/core/tpu/tpu_compile.h +++ b/tensorflow/core/tpu/tpu_compile.h @@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ + #ifndef TENSORFLOW_CORE_TPU_TPU_COMPILE_H_ #define TENSORFLOW_CORE_TPU_TPU_COMPILE_H_ +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" diff --git a/tensorflow/core/tpu/tpu_global_init.cc b/tensorflow/core/tpu/tpu_global_init.cc index e58133d2bd667d..62ef8a23a0ea98 100644 --- a/tensorflow/core/tpu/tpu_global_init.cc +++ b/tensorflow/core/tpu/tpu_global_init.cc @@ -14,33 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/tpu_global_init.h" -#include #include #include #include #include -#include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/tpu_configuration_ops.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h" #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h" -#include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { From 1256c0384894a948728fa63a35bad32790bc4b21 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 27 Jul 2023 16:57:43 -0700 Subject: [PATCH 282/410] [XLA] Allow mixed precision types for all-gather in HLO verifier - all-gather with multiple operands can have different types, so fix HLO verifier to allow that. 
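
For example, a variadic all-gather whose operands have different element types should now pass verification. The sketch below is not new material; it simply mirrors the MixedTypeForAllGatherAllowed test added in this change:

  ENTRY entry {
    p0 = f32[10] parameter(0)
    p1 = bf16[10] parameter(1)
    ROOT ag = (f32[20], bf16[20]) all-gather(p0, p1), dimensions={0}
  }

The new kAllGather/kAllGatherStart/kAllGatherDone cases in CheckMixedPrecisionOperands are grouped with the existing kAllReduce* and kAsync* cases, so all-gather (including its async start/done forms) is treated the same way all-reduce already is with respect to mixed precision operands.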
PiperOrigin-RevId: 551680533 --- tensorflow/compiler/xla/service/hlo_verifier.cc | 3 +++ .../compiler/xla/service/hlo_verifier_test.cc | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index df878750bfdab3..94ea54333c6e5d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1751,6 +1751,9 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kAllReduce: case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: + case HloOpcode::kAllGather: + case HloOpcode::kAllGatherStart: + case HloOpcode::kAllGatherDone: case HloOpcode::kAsyncDone: case HloOpcode::kAsyncUpdate: case HloOpcode::kAsyncStart: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 1fa4f1a4a3a430..24bab5acedf745 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -2909,5 +2909,22 @@ ENTRY entry { TF_ASSERT_OK(status); } +TEST_F(HloVerifierTest, MixedTypeForAllGatherAllowed) { + constexpr absl::string_view kHlo = R"( +HloModule module + +ENTRY entry { + p0 = f32[10] parameter(0) + p1 = bf16[10] parameter(1) + ROOT ag = (f32[20], bf16[20]) all-gather(p0, p1), dimensions={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(kHlo)); + Status status = verifier().Run(module.get()).status(); + + TF_ASSERT_OK(status); +} + } // namespace } // namespace xla From 6ecdaccb5db2e4303a93fde2933bdfc1ae21499c Mon Sep 17 00:00:00 2001 From: Swachhand Lokhande Date: Thu, 27 Jul 2023 17:19:23 -0700 Subject: [PATCH 283/410] Allow DEVICE_GPU to use PJRT for XlaCompile+Run ops. This is guarded by a flag. PiperOrigin-RevId: 551685521 --- tensorflow/core/common_runtime/gpu/gpu_device.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 59b236dc87acb0..939061e61e1d5c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -1792,6 +1792,7 @@ Status BaseGPUDeviceFactory::CreateDevices( auto& pjrt_rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api; pjrt_rollout_config.AllowForDeviceInXlaLaunch(DEVICE_GPU); + pjrt_rollout_config.AllowForDeviceInXlaCompileAndRun(DEVICE_GPU); // Creates PJRT GPU client and places it into a TF global resource manager. auto gpu_run_options = From 79ecccad8b8e972afc41082e44f20f3637c610cf Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 27 Jul 2023 17:28:46 -0700 Subject: [PATCH 284/410] Integrate LLVM at llvm/llvm-project@2854852f4f0f Updates LLVM usage to match [2854852f4f0f](https://github.com/llvm/llvm-project/commit/2854852f4f0f) PiperOrigin-RevId: 551687307 --- .../mlir/tensorflow/utils/cluster_util.cc | 2 +- .../gpu/transforms/outline_cuda_graphs.cc | 5 +- .../xla/mlir/runtime/transforms/BUILD | 1 - .../transforms/compilation_pipeline_cpu.cc | 1 - tensorflow/compiler/xla/mlir_hlo/BUILD | 1 - .../xla/mlir_hlo/transforms/CMakeLists.txt | 1 - .../transforms/generic_host_to_llvm.cc | 2 - .../compiler/xla/service/gpu/openxla/BUILD | 8 +- .../transforms/region_to_functional/impl.cc | 5 +- third_party/llvm/generated.patch | 12 - third_party/llvm/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch: | 4513 +++++++++++++++++ 12 files changed, 4528 insertions(+), 27 deletions(-) create mode 100644 third_party/stablehlo/temporary.patch: diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.cc index 270f3e0a12000c..0557110de35d60 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.cc @@ -203,7 +203,7 @@ void ReorderOpResultUses(mlir::Operation* cluster) { } } - std::vector sorted = ops_to_reorder.takeVector(); + llvm::SmallVector sorted = ops_to_reorder.takeVector(); llvm::sort(sorted, [](mlir::Operation* lhs, mlir::Operation* rhs) { return lhs->isBeforeInBlock(rhs); }); diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc index bbca585ef50bc3..b154c7c11690db 100644 --- a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -320,8 +320,9 @@ static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { [&](Value arg) { return !defined_by_seq.contains(arg); }); args.insert(external_args.begin(), external_args.end()); } - - return args.takeVector(); + llvm::SmallVector args_sv = args.takeVector(); + std::vector args_tv(args_sv.begin(), args_sv.end()); + return args_tv; } // Given a sequence of operations, outline them into a graph capture function diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD index dc9bd544f64fc2..22c2c46edfe4e4 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD @@ -126,7 +126,6 @@ cc_library( "@llvm-project//mlir:GPUToGPURuntimeTransforms", "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathToLLVM", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 0ff7e1b0d366da..e923fb6b041d21 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project #include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD index d905784f0f3119..10da024978bbcb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/BUILD @@ -1298,7 +1298,6 @@ cc_library( "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MathDialect", diff --git a/tensorflow/compiler/xla/mlir_hlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/transforms/CMakeLists.txt index a8e8ef3aed7b44..4d61e07d8e6cb1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/transforms/CMakeLists.txt @@ -51,7 +51,6 @@ add_mlir_library(MLIRBufferTransforms MLIRGPUDialect MLIRHLOAnalysis MLIRIR - MLIRLinalgToLLVM MLIRMathTransforms MLIRPass MLIRReconcileUnrealizedCasts diff --git a/tensorflow/compiler/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/tensorflow/compiler/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index b5d99e3c41a8c3..b338623b1135bc 100644 --- a/tensorflow/compiler/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/tensorflow/compiler/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -107,7 +106,6 @@ class GenericHostToLLVMPass cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); populateSCFToControlFlowConversionPatterns(patterns); populateComplexToLLVMConversionPatterns(typeConverter, patterns); - populateLinalgToLLVMConversionPatterns(typeConverter, patterns); populateMathToLibmConversionPatterns(patterns); deallocation::populateDeallocationToLLVMConversionPatterns(typeConverter, patterns); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index 38aa94529445e2..630a13b02cb2dc 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -37,7 +37,10 @@ package_group( # # compatible with `non_prod` constraint. 
# compatible_with = [], # data = select({ -# ":with_openxla_runtime": ["//third_party/iree/lib:libIREECompiler.so"], +# ":with_openxla_runtime": [ +# # IREE build is currently broken +# # "//third_party/iree/lib:libIREECompiler.so", +# ], # "//conditions:default": [], # }), # deps = [ @@ -55,7 +58,8 @@ package_group( # ":with_openxla_runtime": [ # "//third_party/iree/compiler/bindings/c:headers", # "//third_party/iree/compiler/bindings/c:loader", -# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", +# # IREE build is currently broken +# # "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", # ], # "//conditions:default": [], # }), diff --git a/tensorflow/core/transforms/region_to_functional/impl.cc b/tensorflow/core/transforms/region_to_functional/impl.cc index 7fac57bb7fda91..04b2c96c861988 100644 --- a/tensorflow/core/transforms/region_to_functional/impl.cc +++ b/tensorflow/core/transforms/region_to_functional/impl.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/iterator.h" #include "llvm/Support/ScopedPrinter.h" @@ -465,8 +466,8 @@ FailureOr> BasePattern::CollectValuesDefinedAboveAll( SetVector data_set, ctl_only; for (Region ®ion : llvm::make_pointee_range(regions)) CollectValuesDefinedAbove(region, data_set, ctl_only); - std::vector datas = data_set.takeVector(); - + llvm::SmallVector data_sv = data_set.takeVector(); + std::vector datas(data_sv.begin(), data_sv.end()); // If in any of the regions we found a use of a control token defined above // the regions with no associated data value, then it cannot be converted to // explicit capture unless we insert chain constants. If this option was not diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 2bea0dc6ac565b..509398da979e83 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,13 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. 
-diff -ruN --strip-trailing-cr a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp ---- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp -+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp -@@ -1714,7 +1714,7 @@ - ");\n", - op.getCppClassName()); - } else { -- body << " result.addAttribute(\"odsResultSegmentSizes\", " -+ body << " result.addAttribute(\"result_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - llvm::interleaveComma(op.getResults(), body, interleaveFn); - body << "}));\n"; diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 2644759bfc40b1..2f1b7eba6020a8 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "4706251a3186c34da0ee8fd894f7e6b095da8fdc" - LLVM_SHA256 = "01cdfda6790f2b0d897423c6ba8147af6359f261c5aed62e157b9ecdf2ad591e" + LLVM_COMMIT = "2854852f4f0f1fbb8fa7adb031f921898c8201d6" + LLVM_SHA256 = "5e6575622bb1486d09933f3abd72cf84b4b8d3f0171c850cc5c5290caa4f55b8" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch: b/third_party/stablehlo/temporary.patch: new file mode 100644 index 00000000000000..ddd7f6c1f59a86 --- /dev/null +++ b/third_party/stablehlo/temporary.patch: @@ -0,0 +1,4513 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -279,6 +279,24 @@ + ) + + cc_library( ++ name = "experimental_ops", ++ srcs = [ ++ "stablehlo/dialect/ExperimentalOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/ExperimentalOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( + name = "reference_axes", + srcs = [ + "stablehlo/reference/Axes.cpp", +@@ -677,6 +695,7 @@ + deps = [ + ":base", + ":chlo_ops", ++ ":experimental_ops", + ":stablehlo_ops", + ":stablehlo_ops_inc_gen", + ":stablehlo_pass_inc_gen", +diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt +--- stablehlo/CMakeLists.txt ++++ stablehlo/CMakeLists.txt +@@ -13,135 +13,20 @@ + # See the License for the specific language governing permissions and + # limitations under the License. + # +-cmake_minimum_required(VERSION 3.15.0) + +-if(POLICY CMP0068) +- cmake_policy(SET CMP0068 NEW) +- set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) +-endif() +- +-if(POLICY CMP0075) +- cmake_policy(SET CMP0075 NEW) +-endif() +- +-if(POLICY CMP0077) +- cmake_policy(SET CMP0077 NEW) +-endif() +- +-# CMP0116: Ninja generators transform `DEPFILE`s from `add_custom_command()` +-# New in CMake 3.20. https://cmake.org/cmake/help/latest/policy/CMP0116.html +-if(POLICY CMP0116) +- cmake_policy(SET CMP0116 OLD) +-endif() ++# This build of StableHLO is meant to be embedded in MLIR-HLO. ++# As a result, its root CMakeLists.txt is different from the original ++# CMakeLists.txt from https://github.com/openxla/stablehlo. ++# All other files of this build of StableHLO except for this one are the same ++# as the original files. ++# To get access to a standalone build of StableHLO, check out the ++# openxla/stablehlo repository. 
+ + #------------------------------------------------------------------------------- + # Options and settings + #------------------------------------------------------------------------------- +-option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF) +-option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF) +-option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF) + +-#------------------------------------------------------------------------------- +-# Project setup and globals +-#------------------------------------------------------------------------------- +-set(STABLEHLO_EXTERNAL_PROJECT_BUILD OFF) +- +-if(NOT (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) AND NOT MLIR_BINARY_DIR) +- # Building as part of LLVM via the external project mechanism. +- set(STABLEHLO_EXTERNAL_PROJECT_BUILD ON) +-else() +- # Building standalone. +- project(stablehlo LANGUAGES CXX C) +- set(CMAKE_C_STANDARD 11) +- set(CMAKE_CXX_STANDARD 17) +-endif() +- +-# Build with ccache if the package is present +-set(LLVM_CCACHE_BUILD OFF CACHE BOOL "Set to ON for a ccache enabled build") +-if(LLVM_CCACHE_BUILD) +- find_program(CCACHE_PROGRAM ccache) +- if(CCACHE_PROGRAM) +- set(LLVM_CCACHE_MAXSIZE "" CACHE STRING "Size of ccache") +- set(LLVM_CCACHE_DIR "" CACHE STRING "Directory to keep ccached data") +- set(LLVM_CCACHE_PARAMS "CCACHE_CPP2=yes CCACHE_HASHDIR=yes" +- CACHE STRING "Parameters to pass through to ccache") +- +- set(CCACHE_PROGRAM "${LLVM_CCACHE_PARAMS} ${CCACHE_PROGRAM}") +- if (LLVM_CCACHE_MAXSIZE) +- set(CCACHE_PROGRAM "CCACHE_MAXSIZE=${LLVM_CCACHE_MAXSIZE} ${CCACHE_PROGRAM}") +- endif() +- if (LLVM_CCACHE_DIR) +- set(CCACHE_PROGRAM "CCACHE_DIR=${LLVM_CCACHE_DIR} ${CCACHE_PROGRAM}") +- endif() +- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PROGRAM}) +- else() +- message(FATAL_ERROR "Unable to find the program ccache. Set LLVM_CCACHE_BUILD to OFF") +- endif() +-endif() +- +-#------------------------------------------------------------------------------- +-# MLIR/LLVM Configuration +-#------------------------------------------------------------------------------- +-if (STABLEHLO_ENABLE_STRICT_BUILD) +- set(LLVM_ENABLE_WARNINGS ON) +- set(LLVM_ENABLE_WERROR ON) +- set(LLVM_ENABLE_PEDANTIC ON) +-endif() +- +-# Find MLIR to install if we are building standalone. If building as part of +-# another project, let it handle the MLIR dependency. The dependent project +-# might use a bundled version of MLIR instead of installing, for instance. 
+-if(STABLEHLO_EXTERNAL_PROJECT_BUILD) +- message(STATUS "Building StableHLO as an external LLVM project") +- set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir ) # --src-root +- set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include ) # --includedir +- set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) +- include_directories(SYSTEM ${MLIR_INCLUDE_DIR}) +- include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR}) +- include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR}) +- +- set(BACKEND_PACKAGE_STRING "${PACKAGE_STRING}") +- list(APPEND CMAKE_MODULE_PATH "${MLIR_MAIN_SRC_DIR}/cmake/modules") +-elseif(NOT STABLEHLO_BUILD_EMBEDDED) +- message(STATUS "Building StableHLO with an installed MLIR") +- find_package(MLIR REQUIRED CONFIG) +- message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +- message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +- set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) +- set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) +- list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +- list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +-else() +- message(STATUS "Building StableHLO embedded in another project") +-endif() +- +-if(LLVM_ENABLE_ZLIB) +- find_package(ZLIB) +-endif() +- +-include(TableGen) +-include(AddLLVM) +-include(AddMLIR) +-include(HandleLLVMOptions) +-include_directories(${LLVM_INCLUDE_DIRS}) +-include_directories(${MLIR_INCLUDE_DIRS}) +-include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +-include_directories(${CMAKE_CURRENT_BINARY_DIR}) +-link_directories(${LLVM_BUILD_LIBRARY_DIR}) +-add_definitions(${LLVM_DEFINITIONS}) +- +-#------------------------------------------------------------------------------- +-# Python configuration +-#------------------------------------------------------------------------------- +- +-if(STABLEHLO_ENABLE_BINDINGS_PYTHON) +- if(NOT STABLEHLO_EXTERNAL_PROJECT_BUILD) +- message(WARNING "StableHLO Python bindings are not supported in standalone mode") +- endif() +- +- include(MLIRDetectPythonEnv) +- mlir_configure_python_dev_packages() +-endif() ++set(STABLEHLO_ENABLE_BINDINGS_PYTHON ${MHLO_ENABLE_BINDINGS_PYTHON}) + + #------------------------------------------------------------------------------- + # Directory setup +diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp +--- stablehlo/stablehlo/dialect/Base.cpp ++++ stablehlo/stablehlo/dialect/Base.cpp +@@ -156,6 +156,7 @@ + DenseIntElementsAttr attr; + if (!matchPattern(value, m_Constant(&attr))) return failure(); + ++ // Signless types are treated as signed, per StableHLO convention. + // Unless the type is i1 (which models boolean type from the StableHLO spec), + // in which case it's considered to be unsigned. 
+ auto elementType = attr.getType().getElementType(); +@@ -599,5 +600,18 @@ + return UnrankedTensorType::get(components.getElementType()); + } + ++DenseIntElementsAttr getPaddingAttr(MLIRContext* context, ++ ArrayRef values) { ++ return DenseIntElementsAttr::get( ++ RankedTensorType::get({static_cast(values.size()) / 2, 2}, ++ IntegerType::get(context, 64)), ++ values); ++} ++ ++DenseIntElementsAttr getPaddingAttr(Builder* builder, ++ ArrayRef values) { ++ return getPaddingAttr(builder->getContext(), values); ++} ++ + } // namespace hlo + } // namespace mlir +diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h +--- stablehlo/stablehlo/dialect/Base.h ++++ stablehlo/stablehlo/dialect/Base.h +@@ -194,6 +194,10 @@ + + ShapedType createShapedType(ShapedTypeComponents components); + ++DenseIntElementsAttr getPaddingAttr(MLIRContext *context, ++ ArrayRef value); ++DenseIntElementsAttr getPaddingAttr(Builder *builder, ArrayRef value); ++ + // This interface is implemented by both StableHLO and MHLO dialects + // and is used as the foundation for sharing verification, type inference and + // prettyprinting logic between them. +@@ -249,6 +253,10 @@ + template + class BroadcastingElementwise + : public mlir::OpTrait::TraitBase {}; ++ ++template ++class IsCommutative ++ : public mlir::OpTrait::TraitBase {}; + + template + class PairwiseSameOperandAndResultType +diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td +--- stablehlo/stablehlo/dialect/Base.td ++++ stablehlo/stablehlo/dialect/Base.td +@@ -188,6 +188,11 @@ + // An operation that is essentially element-wise but may implement broadcasting + // semantics. + def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">; ++ ++// This class adds property that the operation is commutative. ++// Upstream IsCommutative has default folders, and StableHLO aims to have no ++// default folders or canonicalizations. ++def HLO_Commutative : HLO_NativeOpTrait<"IsCommutative">; + + // Op has pairwise operand and result type matching: the number of operands + // must be equal to the number of results and the type of ith operand must +diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt +--- stablehlo/stablehlo/dialect/CMakeLists.txt ++++ stablehlo/stablehlo/dialect/CMakeLists.txt +@@ -77,6 +77,20 @@ + target_include_directories(ChloOps INTERFACE + $ + $ ++) ++ ++add_mlir_dialect_library(ExperimentalOps ++ PARTIAL_SOURCES_INTENDED ++ ExperimentalOps.cpp ++ ++ DEPENDS ++ StablehloOpsIncGen ++ ++ LINK_LIBS PUBLIC ++ MLIRFuncDialect ++ MLIRIR ++ MLIRSupport ++ StablehloOps + ) + + add_mlir_dialect_library(StablehloRegister +diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stablehlo/dialect/ExperimentalOps.cpp +--- stablehlo/stablehlo/dialect/ExperimentalOps.cpp ++++ stablehlo/stablehlo/dialect/ExperimentalOps.cpp +@@ -0,0 +1,392 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#include "stablehlo/dialect/ExperimentalOps.h" ++ ++#include ++ ++#include "llvm/ADT/ArrayRef.h" ++#include "llvm/ADT/STLExtras.h" ++#include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" ++#include "mlir/IR/Types.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++LogicalResult DynamicReduceWindowOpAdaptor::verify() { ++ // Before checking the constraints inherited from ReduceWindowOp, ++ // make sure that the operands and the attributes of the underlying custom ++ // call make sense. ++ if (op_->getNumOperands() != 2 * op_->getNumResults() + 5) ++ return op_.emitError("expects size(operands) = 2 * size(results) + 5"); ++ if (op_->getNumResults() == 0) ++ return op_.emitError("expects size(results) > 0"); ++ for (const auto& attr : op_->getAttrs()) { ++ // api_version and backend_config have default values. ++ // call_target_name should be "stablehlo.dynamic_reduce_window". ++ // called_computations carries the body. ++ if (attr.getName() != "api_version" && ++ attr.getName() != "backend_config" && ++ attr.getName() != "call_target_name" && ++ attr.getName() != "called_computations") ++ return op_.emitError() ++ << attr.getName() << " is not a supported attribute"; ++ } ++ if (!op_.getBackendConfig().empty()) ++ return op_.emitError() << "expects an empty backend_config"; ++ if (op_.getCallTargetName() != "stablehlo.dynamic_reduce_window") ++ return op_.emitError() << "expects @stablehlo.dynamic_reduce_window"; ++ ++ // Unpack operands and attributes of the underlying custom call into ++ // operation-specific inputs. ++ auto numInputs = getInputs().size(); ++ auto inputs = op_.getInputs().slice(0, numInputs); ++ auto initValues = op_.getInputs().slice(numInputs, numInputs); ++ auto windowDimensions = op_.getInputs()[op_.getInputs().size() - 5]; ++ auto windowStrides = op_.getInputs()[op_.getInputs().size() - 4]; ++ auto baseDilations = op_.getInputs()[op_.getInputs().size() - 3]; ++ auto windowDilations = op_.getInputs()[op_.getInputs().size() - 2]; ++ auto padding = op_.getInputs()[op_.getInputs().size() - 1]; ++ auto results = op_.getResults(); ++ ++ // reduce_window_c1 ++ // This constraint hold automatically thanks to the checks that we have ++ // performed above. ++ ++ // reduce_window_i1 ++ SmallVector inputTypes; ++ for (auto [index, input] : llvm::enumerate(inputs)) { ++ auto inputType = input.getType().dyn_cast(); ++ inputTypes.push_back(inputType); ++ if (!inputType) ++ return op_.emitError() ++ << "expects inputs (e.g. operand #" << index << ") to be tensors"; ++ } ++ ++ // reduce_window_i2 ++ SmallVector initValueTypes; ++ for (auto [index, initValue] : llvm::enumerate(initValues)) { ++ auto initValueType = initValue.getType().dyn_cast(); ++ initValueTypes.push_back(initValueType); ++ if (!initValueType || !initValueType.hasRank() || ++ initValueType.getRank() != 0) ++ return op_.emitError() << "expects init_values (e.g. 
operand #" ++ << numInputs + index << ") " ++ << "to be 0-dimensional tensors"; ++ } ++ ++ // reduce_window_i3...reduce_window_i7 ++ auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr, ++ int64_t expectedRank) -> LogicalResult { ++ auto type = dynamicAttr.getType().dyn_cast(); ++ if (!type || !type.hasRank() || type.getRank() != expectedRank || ++ !type.getElementType().isIntOrIndex()) { ++ if (index < 0) index += op_->getNumOperands(); ++ return op_.emitError() ++ << "expects " << name << " (operand #" << index << ") " ++ << "to be a " << expectedRank << "-dimensional tensor " ++ << "of integer or index type"; ++ } ++ return success(); ++ }; ++ if (failed(checkRank("window_dimensions", -5, windowDimensions, 1)) || ++ failed(checkRank("window_strides", -4, windowStrides, 1)) || ++ failed(checkRank("base_dilations", -3, baseDilations, 1)) || ++ failed(checkRank("window_dilations", -2, windowDilations, 1)) || ++ failed(checkRank("padding", -1, padding, 2))) ++ return failure(); ++ ++ // reduce_window_i7 ++ auto paddingType = getPadding().getType().dyn_cast(); ++ if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 || ++ paddingType.getDimSize(1) != 2 || ++ !paddingType.getElementType().isIntOrIndex()) ++ return op_.emitError() ++ << "expects padding_type (operand #" << op_.getNumOperands() - 1 ++ << ") to be a 2-dimensional tensor of integer or index type"; ++ ++ // reduce_window_c2 ++ std::optional> inputShape; ++ for (auto inputType : inputTypes) { ++ if (!inputType.hasRank()) continue; ++ if (!inputShape) inputShape = inputType.getShape(); ++ if (failed(verifyCompatibleShape(inputType.getShape(), *inputShape))) ++ return op_.emitError() << "expects all inputs (operands 0.." << numInputs ++ << ") to have compatible shapes"; ++ } ++ ++ // reduce_window_c3 ++ for (auto [inputType, initValueType] : ++ llvm::zip(inputTypes, initValueTypes)) { ++ if (inputType.getElementType() != initValueType.getElementType()) ++ return op_.emitError() << "expects inputs (operands 0.." << numInputs ++ << ") and init_values (operands " << numInputs ++ << ".." << numInputs * 2 << ") to have pairwise " ++ << "the same element types"; ++ } ++ ++ // reduce_window_c4...reduce_window_c12 ++ // In this range, we only verify the constraints with even numbers. ++ // Verifying the constraints with odd numbers would require knowing the ++ // actual values of window_dimensions, window_strides, etc. ++ // While we certainly can try to check whether they are constants and ++ // verify them in that case, that seems like too much at this point. 
++ auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr, ++ ArrayRef expectedShape) -> LogicalResult { ++ auto type = dynamicAttr.getType().cast(); ++ if (type.getShape() != expectedShape) { ++ if (index < 0) index += op_->getNumOperands(); ++ return op_.emitError() ++ << "expects " << name << " (operand #" << index << ") " ++ << "to have shape [" << expectedShape << "]"; ++ } ++ return success(); ++ }; ++ if (inputShape) { ++ auto inputRank = static_cast(inputShape->size()); ++ if (failed(checkShape("window_dimensions", -5, windowDimensions, ++ {inputRank})) || ++ failed(checkShape("window_strides", -4, windowStrides, {inputRank})) || ++ failed(checkShape("base_dilations", -3, baseDilations, {inputRank})) || ++ failed( ++ checkShape("window_dilations", -2, windowDilations, {inputRank})) || ++ failed(checkShape("padding", -1, padding, {inputRank, 2}))) ++ return failure(); ++ } ++ ++ // reduce_window_c13 ++ if (op_.getCalledComputations().size() != 1) ++ return op_.emitError() << "expects called_computations to have 1 element"; ++ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyFunc = ++ op_->getParentOfType().lookupSymbol(bodyAttr); ++ if (!bodyFunc) ++ return op_.emitError() << "expects called_computations to refer to " ++ << "a function that exists within a parent module"; ++ ++ // reduce_window_c13 ++ SmallVector expectedBodyInputs; ++ llvm::append_range(expectedBodyInputs, initValueTypes); ++ llvm::append_range(expectedBodyInputs, initValueTypes); ++ SmallVector expectedBodyOutputs; ++ llvm::append_range(expectedBodyOutputs, initValueTypes); ++ auto expectedBodyType = FunctionType::get( ++ op_.getContext(), expectedBodyInputs, expectedBodyOutputs); ++ if (bodyFunc.getFunctionType() != expectedBodyType) ++ return op_.emitError() << "expects body to have type " << expectedBodyType; ++ ++ // reduce_window_c14 ++ SmallVector resultTypes; ++ std::optional> resultShape; ++ for (auto result : results) { ++ auto resultType = result.getType().dyn_cast(); ++ resultTypes.push_back(resultType); ++ if (!resultType) return op_.emitError() << "expects results to be tensors"; ++ ++ if (!resultType.hasRank()) continue; ++ if (!resultShape) resultShape = resultType.getShape(); ++ if (failed(verifyCompatibleShape(resultType.getShape(), *resultShape))) ++ return op_.emitError() << "expects all results to have compatible shapes"; ++ } ++ ++ // reduce_window_c15 ++ // Verifying this constraint would require knowing the actual values of ++ // window_dimensions, window_strides, etc. ++ // While we certainly can try to check whether they are constants and ++ // verify them in that case, that seems like too much at this point. ++ ++ // reduce_window_c16 ++ for (auto [resultType, initValueType] : ++ llvm::zip(resultTypes, initValueTypes)) { ++ if (resultType.getElementType() != initValueType.getElementType()) ++ return op_.emitError() << "expects results and init_values (operands " ++ << numInputs << ".." 
<< numInputs * 2 << ") " ++ << "to have pairwise the same element types"; ++ } ++ ++ return success(); ++} ++ ++ValueRange DynamicReduceWindowOpAdaptor::getInputs() { ++ auto numInputs = (op_.getInputs().size() - 5) / 2; ++ return op_.getInputs().slice(0, numInputs); ++} ++ ++ValueRange DynamicReduceWindowOpAdaptor::getInitValues() { ++ auto numInputs = (op_.getInputs().size() - 5) / 2; ++ return op_.getInputs().slice(numInputs, numInputs); ++} ++ ++TypedValue DynamicReduceWindowOpAdaptor::getWindowDimensions() { ++ return op_.getInputs()[op_.getInputs().size() - 5] ++ .cast>(); ++} ++ ++TypedValue DynamicReduceWindowOpAdaptor::getWindowStrides() { ++ return op_.getInputs()[op_.getInputs().size() - 4] ++ .cast>(); ++} ++ ++TypedValue DynamicReduceWindowOpAdaptor::getBaseDilations() { ++ return op_.getInputs()[op_.getInputs().size() - 3] ++ .cast>(); ++} ++ ++TypedValue DynamicReduceWindowOpAdaptor::getWindowDilations() { ++ return op_.getInputs()[op_.getInputs().size() - 2] ++ .cast>(); ++} ++ ++TypedValue DynamicReduceWindowOpAdaptor::getPadding() { ++ return op_.getInputs()[op_.getInputs().size() - 1] ++ .cast>(); ++} ++ ++Region& DynamicReduceWindowOpAdaptor::getBody() { ++ auto bodyAttr = op_.getCalledComputations()[0].cast(); ++ auto bodyFunc = ++ op_->getParentOfType().lookupSymbol(bodyAttr); ++ return bodyFunc.getBody(); ++} ++ ++ValueRange DynamicReduceWindowOpAdaptor::getResults() { ++ return op_.getResults(); ++} ++ ++std::optional getDynamicReduceWindowOp( ++ CustomCallOp op) { ++ if (op.getCallTargetName() != "stablehlo.dynamic_reduce_window") return {}; ++ return DynamicReduceWindowOpAdaptor(op); ++} ++ ++LogicalResult DynamicRngBitGeneratorOpAdaptor::verify() { ++ // Before checking the constraints inherited from RngBitGeneratorOp, ++ // make sure that the operands and the attributes of the underlying custom ++ // call make sense. ++ if (op_->getNumOperands() != 2) ++ return op_.emitError("expects size(operands) = 2"); ++ if (op_->getNumResults() != 2) ++ return op_.emitError("expects size(results) = 2"); ++ for (const auto& attr : op_->getAttrs()) { ++ // api_version and backend_config have default values. ++ // call_target_name should be "stablehlo.dynamic_rng_bit_generator". ++ // rng_algorithm comes from the operation. ++ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && ++ attr.getName() != "call_target_name" && ++ attr.getName() != "rng_algorithm") ++ return op_.emitError() ++ << attr.getName() << " is not a supported attribute"; ++ } ++ if (!op_.getBackendConfig().empty()) ++ return op_.emitError() << "expects an empty backend_config"; ++ if (op_.getCallTargetName() != "stablehlo.dynamic_rng_bit_generator") ++ return op_.emitError() << "expects @stablehlo.dynamic_rng_bit_generator"; ++ if (!op_->hasAttr("rng_algorithm")) ++ return op_.emitError() << "expects an rng_algorithm"; ++ ++ // Unpack operands and attributes of the underlying custom call into ++ // operation-specific inputs. ++ auto rngAlgorithmAttr = op_->getAttr("rng_algorithm"); ++ auto initialState = op_.getInputs()[0]; ++ auto outputShape = op_.getInputs()[1]; ++ auto outputState = op_.getResults()[0]; ++ auto output = op_.getResults()[1]; ++ ++ // dynamic_rng_bit_generator_i1 ++ if (!rngAlgorithmAttr.isa()) ++ return op_.emitError() ++ << "expects a #stablehlo rng_algorithm"; ++ ++ // dynamic_rng_bit_generator_i2 ++ // TODO(#643): Clarify supported types for RngBitGeneratorOp. 
++ auto initialStateType = initialState.getType().dyn_cast(); ++ if (!initialStateType || !initialStateType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects initial_state (operand #0) " ++ << "to be a tensor of integer or floating-point type"; ++ ++ // dynamic_rng_bit_generator_i3 ++ auto outputShapeType = outputShape.getType().dyn_cast(); ++ if (!outputShapeType || !outputShapeType.hasRank() || ++ outputShapeType.getRank() != 1 || ++ !outputShapeType.getElementType().isIntOrIndex()) ++ return op_.emitError() ++ << "expects output_shape (operand #1) " ++ << "to be a 1-dimensional tensor of integer or index type"; ++ ++ // dynamic_rng_bit_generator_o1 ++ // TODO(#643): Clarify supported types for RngBitGeneratorOp. ++ auto outputStateType = outputState.getType().dyn_cast(); ++ if (!outputStateType || !outputStateType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects output_state (result #0) " ++ << "to be a tensor of integer or floating-point type"; ++ ++ // dynamic_rng_bit_generator_o2 ++ auto outputType = output.getType().dyn_cast(); ++ if (!outputType || !outputType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects output (result #1) " ++ << "to be a tensor of integer or floating-point type"; ++ ++ // dynamic_rng_bit_generator_c1 ++ if (!hlo::isCompatibleForHloTypeInference(initialStateType, outputStateType)) ++ return op_.emitError() ++ << "expects initial_state (operand #0) and output_state (result #0) " ++ << "to have compatible shapes"; ++ ++ // dynamic_rng_bit_generator_c2 ++ // TODO(#486): Verify rng_algorithm in RngBitGeneratorOp. ++ ++ // dynamic_rng_bit_generator_c3 ++ if (!hlo::isCompatibleForHloTypeInference(outputShape, outputType)) ++ return op_.emitError() << "expects output (result #1) to have shape " ++ << "compatible with output_shape (operand #2)"; ++ ++ return success(); ++} ++ ++RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() { ++ return op_->getAttr("rng_algorithm").cast().getValue(); ++} ++ ++TypedValue DynamicRngBitGeneratorOpAdaptor::getInitialState() { ++ return op_.getInputs()[0].cast>(); ++} ++ ++TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputShape() { ++ return op_.getInputs()[1].cast>(); ++} ++ ++TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputState() { ++ return op_.getResults()[0].cast>(); ++} ++ ++TypedValue DynamicRngBitGeneratorOpAdaptor::getOutput() { ++ return op_.getResults()[1].cast>(); ++} ++ ++std::optional getDynamicRngBitGeneratorOp( ++ CustomCallOp op) { ++ if (op.getCallTargetName() != "stablehlo.dynamic_rng_bit_generator") ++ return {}; ++ return DynamicRngBitGeneratorOpAdaptor(op); ++} ++ ++} // namespace stablehlo ++} // namespace mlir +diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo/dialect/ExperimentalOps.h +--- stablehlo/stablehlo/dialect/ExperimentalOps.h ++++ stablehlo/stablehlo/dialect/ExperimentalOps.h +@@ -0,0 +1,170 @@ ++/* Copyright 2023 The StableHLO Authors. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++==============================================================================*/ ++ ++#ifndef STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H ++#define STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H ++ ++// This file supports XLA-specific experiments with the StableHLO opset. ++// These experiments are not yet ready to be upstreamed to openxla/stablehlo ++// and are incubating towards the respective StableHLO RFCs. ++// ++// Custom calls (which are the implementation vehicle of these experiments) ++// don't have compatibility guarantees within the StableHLO process, but ++// the StableHLO team at Google provides out-of-band guarantees for these ++// custom calls, with the same compatibility window as StableHLO upstream. ++ ++#include "mlir/IR/Operation.h" ++#include "mlir/IR/Region.h" ++#include "mlir/IR/Value.h" ++#include "mlir/IR/ValueRange.h" ++#include "mlir/Support/LogicalResult.h" ++#include "stablehlo/dialect/StablehloOps.h" ++ ++namespace mlir { ++namespace stablehlo { ++ ++// The DynamicReduceWindowOp experiment provides a dynamic version of ++// ReduceWindowOp. Once the dynamism RFC is figured out, we expect to have an ++// upstream representation for this notion. ++// ++// Within this experiment, DynamicReduceWindowOp is represented via the ++// `stablehlo.custom_call @stablehlo.dynamic_reduce_window` custom call. ++// This custom call has the following operands which represent a dynamic version ++// of operands and attributes of ReduceWindowOp: ++// * [0:N] => inputs ++// * [N:2*N] => init_values ++// * [-5] => window_dimensions ++// * [-4] => window_strides ++// * [-3] => base_dilations ++// * [-2] => window_dilations ++// * [-1] => padding ++// Additionally, to represent the body of DynamicReduceWindowOp, the custom call ++// has a satellite function attached to the custom call via called_computations. ++// ++// Semantics of DynamicReduceWindowOp are inherited from semantics of ++// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window ++// with the following exceptions: ++// 1) All tensor constants, i.e. window_dimensions, window_strides, ++// base_dilations, window_dilations and padding, become tensors of ++// integer type. ++// 2) As a result, some of the constraints can no longer be validated ++// statically. However, this operation still expects these constraints ++// to hold dynamically, and if they don't hold, the behavior is undefined. ++class DynamicReduceWindowOpAdaptor { ++ public: ++ DynamicReduceWindowOpAdaptor(CustomCallOp op) : op_(op) {} ++ operator Operation*() { return op_; } ++ Operation* operator->() { return op_; } ++ ++ // Same accessors as for stablehlo::ReduceWindowOp, except that all the ++ // std::optional attributes have turned into values. ++ // These accessors assume that the operation is well-formed (i.e. that it ++ // can pass verification). ++ ValueRange getInputs(); ++ ValueRange getInitValues(); ++ TypedValue getWindowDimensions(); ++ TypedValue getWindowStrides(); ++ TypedValue getBaseDilations(); ++ TypedValue getWindowDilations(); ++ TypedValue getPadding(); ++ Region& getBody(); ++ ValueRange getResults(); ++ ++ // Verifies the constraints documented above. ++ // Emits errors if errors are detected. ++ LogicalResult verify(); ++ ++ private: ++ CustomCallOp op_; ++}; ++ ++// Wraps a custom call in a DynamicReduceWindowAdaptor. ++// Fails if the call_target_name of the custom call doesn't match ++// "stablehlo.dynamic_reduce_window". 
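++//
++// For illustration, such a custom call might look roughly as follows; the
++// tensor shapes and the @reduce_body symbol below are hypothetical, not
++// taken from this patch:
++//
++//   %result = "stablehlo.custom_call"(%input, %init,
++//       %window_dimensions, %window_strides, %base_dilations,
++//       %window_dilations, %padding) {
++//     call_target_name = "stablehlo.dynamic_reduce_window",
++//     called_computations = [@reduce_body]
++//   } : (tensor<?x?xf32>, tensor<f32>, tensor<2xi64>, tensor<2xi64>,
++//        tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?x?xf32>
++//
++// Here there is N = 1 input and 1 init_value, followed by the five trailing
++// operands in the order documented above, so size(operands) = 2 * 1 + 5 = 7,
++// and @reduce_body would have type (tensor<f32>, tensor<f32>) -> tensor<f32>.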
++std::optional getDynamicReduceWindowOp( ++ CustomCallOp op); ++ ++// The DynamicRngBitGeneratorOp experiment provides a dynamic version of ++// RngBitGeneratorOp. Once the dynamism RFC is figured out, we expect to have an ++// upstream representation for this notion. ++// ++// Within this experiment, DynamicRngBitGeneratorOp is represented via the ++// `stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator` custom call. ++// This custom call has the regular operand of RngBitGeneratorOp plus an ++// additional `output_shape` operand that determines the shape of the output: ++// * [0] => initial_state ++// * [1] => output_shape ++// ++// Semantics of DynamicRngBitGeneratorOp are inherited from semantics of ++// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator ++// extended with an additional input (I3) and an additional constraint (C3): ++// ++// #### Inputs ++// ++// | Label | Name | Type | ++// |-------|-----------------|----------------------------------------------| ++// | (I1) | `rng_algorithm` | enum of `DEFAULT`, `THREE_FRY`, and `PHILOX` | ++// | (I2) | `initial_state` | 1-dimensional tensor of type `ui64` | ++// | (I3) | `output_shape` | 1-dimensional tensor of integer type | ++// ++// #### Outputs ++// ++// | Name | Type | ++// |----------------|------------------------------------------| ++// | `output_state` | 1-dimensional tensor of type `ui64` | ++// | `output` | tensor of integer or floating-point type | ++// ++// #### Constraints ++// ++// * (C1) `type(initial_state) = type(output_state)`. ++// * (C2) `size(initial_state)` is defined as: ++// * implementation-defined if `rng_algorithm = DEFAULT`. ++// * `2` if `rng_algorithm = THREE_FRY`. ++// * `2` or `3` if `rng_algorithm = PHILOX`. ++// * (C3) `shape(output) = output_shape`. ++class DynamicRngBitGeneratorOpAdaptor { ++ public: ++ DynamicRngBitGeneratorOpAdaptor(CustomCallOp op) : op_(op) {} ++ operator Operation*() { return op_; } ++ Operation* operator->() { return op_; } ++ ++ // Same accessors as for stablehlo::RngBitGeneratorOp, extended with the ++ // additional `output_shape` operand. ++ // These accessors assume that the operation is well-formed (i.e. that it ++ // can pass verification). ++ RngAlgorithm getRngAlgorithm(); ++ TypedValue getInitialState(); ++ TypedValue getOutputShape(); ++ TypedValue getOutputState(); ++ TypedValue getOutput(); ++ ++ // Verifies the constraints documented above. ++ // Emits errors if errors are detected. ++ LogicalResult verify(); ++ ++ private: ++ CustomCallOp op_; ++}; ++ ++// Wraps a custom call in a DynamicReduceWindowAdaptor. ++// Fails if the call_target_name of the custom call doesn't match ++// "stablehlo.dynamic_rng_bit_generator". 
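++//
++// For illustration, such a custom call might look roughly as follows; the
++// tensor shapes below are hypothetical, not taken from this patch:
++//
++//   %output_state, %output = "stablehlo.custom_call"(%initial_state,
++//       %output_shape) {
++//     call_target_name = "stablehlo.dynamic_rng_bit_generator",
++//     rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
++//   } : (tensor<2xui64>, tensor<2xi64>)
++//         -> (tensor<2xui64>, tensor<?x?xui32>)
++//
++// The rank-2 dynamic shape of %output matches the two elements of
++// %output_shape (C3), and %output_state has the same type as
++// %initial_state (C1), whose size 2 is consistent with THREE_FRY (C2).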
++std::optional getDynamicRngBitGeneratorOp( ++ CustomCallOp op); ++ ++} // namespace stablehlo ++} // namespace mlir ++ ++#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -1467,7 +1467,7 @@ + if (innerOp.getNumOperands() != 2 || + !innerOp.hasTrait() || + !hasSameOperandAndResultTypes(innerOp) || +- !innerOp.hasTrait() || ++ !innerOp.hasTrait() || + !innerOp.hasTrait()) + return false; + +@@ -1664,7 +1664,7 @@ + if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") || + !innerOpNameInfo->hasTrait::Impl>() || + !innerOpNameInfo->hasTrait() || +- !innerOpNameInfo->hasTrait() || ++ !innerOpNameInfo->hasTrait() || + !innerOpNameInfo->hasTrait()) { + parser.emitError(loc, + "expected the inner-op to be a commutative binary-op from " +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td +--- stablehlo/stablehlo/dialect/StablehloOps.td ++++ stablehlo/stablehlo/dialect/StablehloOps.td +@@ -687,7 +687,7 @@ + } + + def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Add operation"; + let description = [{ + Performs element-wise addition of two tensors `lhs` and `rhs` and produces a +@@ -769,7 +769,7 @@ + } + + def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Max operation"; + let description = [{ + Performs element-wise max operation on tensors `lhs` and `rhs` and produces +@@ -786,7 +786,7 @@ + } + + def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Min operation"; + let description = [{ + Performs element-wise min operation on tensors `lhs` and `rhs` and produces a +@@ -803,7 +803,7 @@ + } + + def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply", +- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let summary = "Mul operation"; + let description = [{ + Performs element-wise product of two tensors `lhs` and `rhs` and produces a +@@ -933,7 +933,7 @@ + // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + class StableHLO_BinaryBiwiseOrLogicalElementwiseOp : + StableHLO_BinaryElementwiseOp { ++ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { + let arguments = (ins + HLO_PredOrIntTensor:$lhs, + HLO_PredOrIntTensor:$rhs +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<6.914060e-01> : tensor<20x20xbf16> + %11 = stablehlo.add %8, %10 : tensor<20x20xbf16> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> +- %13 = stablehlo.add %12, %0 : tensor<20x20xbf16> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xbf16> + %14 = 
stablehlo.constant dense<-1.000000e+00> : tensor<20x20xbf16> +- %15 = stablehlo.add %14, %0 : tensor<20x20xbf16> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xbf16> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xbf16> + %17 = stablehlo.sqrt %16 : tensor<20x20xbf16> + %18 = stablehlo.add %0, %17 : tensor<20x20xbf16> +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<6.933590e-01> : tensor<20x20xf16> + %11 = stablehlo.add %8, %10 : tensor<20x20xf16> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> +- %13 = stablehlo.add %12, %0 : tensor<20x20xf16> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xf16> + %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf16> +- %15 = stablehlo.add %14, %0 : tensor<20x20xf16> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xf16> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xf16> + %17 = stablehlo.sqrt %16 : tensor<20x20xf16> + %18 = stablehlo.add %0, %17 : tensor<20x20xf16> +diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir +@@ -16,9 +16,9 @@ + %10 = stablehlo.constant dense<0.693147182> : tensor<20x20xf32> + %11 = stablehlo.add %8, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> +- %13 = stablehlo.add %12, %0 : tensor<20x20xf32> ++ %13 = stablehlo.add %0, %12 : tensor<20x20xf32> + %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %15 = stablehlo.add %14, %0 : tensor<20x20xf32> ++ %15 = stablehlo.add %0, %14 : tensor<20x20xf32> + %16 = stablehlo.multiply %13, %15 : tensor<20x20xf32> + %17 = stablehlo.sqrt %16 : tensor<20x20xf32> + %18 = stablehlo.add %0, %17 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xbf16> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xbf16> + %7 = stablehlo.sqrt %6 : tensor<20x20xbf16> +- %8 = stablehlo.add %3, %7 : tensor<20x20xbf16> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xbf16> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xbf16> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xbf16> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xbf16> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xbf16>, tensor<20x20xbf16>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xcomplex> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xcomplex> + %7 = stablehlo.sqrt %6 : tensor<20x20xcomplex> +- %8 = stablehlo.add %3, %7 : tensor<20x20xcomplex> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xcomplex> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xcomplex> +- 
%10 = stablehlo.multiply %2, %9 : tensor<20x20xcomplex> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xcomplex> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xcomplex>, tensor<20x20xcomplex>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xf16> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xf16> + %7 = stablehlo.sqrt %6 : tensor<20x20xf16> +- %8 = stablehlo.add %3, %7 : tensor<20x20xf16> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xf16> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf16> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf16> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf16> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf16>, tensor<20x20xf16>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir +@@ -11,9 +11,9 @@ + %5 = stablehlo.multiply %0, %0 : tensor<20x20xf32> + %6 = stablehlo.subtract %4, %5 : tensor<20x20xf32> + %7 = stablehlo.sqrt %6 : tensor<20x20xf32> +- %8 = stablehlo.add %3, %7 : tensor<20x20xf32> ++ %8 = stablehlo.add %7, %3 : tensor<20x20xf32> + %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf32> +- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> ++ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> + %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor + return %11 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> + %23 = stablehlo.add %21, %22 : tensor<20x20xbf16> + %24 = stablehlo.sqrt %23 : tensor<20x20xbf16> +- %25 = stablehlo.add %18, %24 : tensor<20x20xbf16> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xbf16> + %26 = stablehlo.divide %17, %25 : tensor<20x20xbf16> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xbf16> + %28 = stablehlo.add %15, %27 : tensor<20x20xbf16> +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> + %23 = stablehlo.add %21, %22 : tensor<20x20xf16> + %24 = stablehlo.sqrt %23 : tensor<20x20xf16> +- %25 = stablehlo.add %18, %24 : tensor<20x20xf16> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xf16> + %26 = stablehlo.divide %17, %25 : tensor<20x20xf16> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xf16> + %28 = stablehlo.add %15, %27 : tensor<20x20xf16> +diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir ++++ 
stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir +@@ -28,7 +28,7 @@ + %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %23 = stablehlo.add %21, %22 : tensor<20x20xf32> + %24 = stablehlo.sqrt %23 : tensor<20x20xf32> +- %25 = stablehlo.add %18, %24 : tensor<20x20xf32> ++ %25 = stablehlo.add %24, %18 : tensor<20x20xf32> + %26 = stablehlo.divide %17, %25 : tensor<20x20xf32> + %27 = stablehlo.multiply %16, %26 : tensor<20x20xf32> + %28 = stablehlo.add %15, %27 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir +@@ -33,7 +33,7 @@ + %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %9 = stablehlo.constant dense<5.000000e-01> : tensor + %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> ++ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<2.000000e+00> : tensor + %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> +@@ -133,7 +133,7 @@ + %108 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> + %109 = stablehlo.add %106, %108 : tensor<20x20xf32> + %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> +- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> ++ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> + %112 = stablehlo.constant dense<5.000000e-01> : tensor + %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %114 = stablehlo.constant dense<3.200000e+01> : tensor +@@ -182,7 +182,7 @@ + %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %158 = stablehlo.add %155, %157 : tensor<20x20xf32> + %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> +- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> ++ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> + %161 = stablehlo.sqrt %3 : tensor<20x20xf32> + %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> + %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir +@@ -33,7 +33,7 @@ + %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %9 = stablehlo.constant dense<5.000000e-01> : tensor + %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> ++ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> + %12 = stablehlo.constant dense<2.000000e+00> : tensor + %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> +@@ -133,7 +133,7 @@ + %108 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> + %109 = stablehlo.add %106, %108 : tensor<20x20xf32> + %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> +- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> ++ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> + %112 = stablehlo.constant dense<5.000000e-01> : tensor + %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %114 = stablehlo.constant 
dense<3.200000e+01> : tensor +@@ -182,7 +182,7 @@ + %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %158 = stablehlo.add %155, %157 : tensor<20x20xf32> + %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> +- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> ++ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> + %161 = stablehlo.sqrt %3 : tensor<20x20xf32> + %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> + %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir +@@ -32,7 +32,7 @@ + %7 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %8 = stablehlo.constant dense<5.000000e-01> : tensor + %9 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> +- %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> ++ %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> + %11 = stablehlo.constant dense<2.000000e+00> : tensor + %12 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %13 = stablehlo.subtract %10, %12 : tensor<20x20xf32> +@@ -132,7 +132,7 @@ + %107 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> + %108 = stablehlo.add %105, %107 : tensor<20x20xf32> + %109 = stablehlo.subtract %108, %98 : tensor<20x20xf32> +- %110 = stablehlo.multiply %7, %109 : tensor<20x20xf32> ++ %110 = stablehlo.multiply %109, %7 : tensor<20x20xf32> + %111 = stablehlo.constant dense<5.000000e-01> : tensor + %112 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> + %113 = stablehlo.constant dense<3.200000e+01> : tensor +@@ -181,7 +181,7 @@ + %156 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> + %157 = stablehlo.add %154, %156 : tensor<20x20xf32> + %158 = stablehlo.subtract %157, %147 : tensor<20x20xf32> +- %159 = stablehlo.multiply %112, %158 : tensor<20x20xf32> ++ %159 = stablehlo.multiply %158, %112 : tensor<20x20xf32> + %160 = stablehlo.sqrt %2 : tensor<20x20xf32> + %161 = stablehlo.divide %159, %160 : tensor<20x20xf32> + %162 = stablehlo.select %5, %110, %161 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> + %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> + %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> 
+ %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> + %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir +@@ -10,7 +10,7 @@ + %4 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %5 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> + %6 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> +- %7 = stablehlo.multiply %3, %2 : tensor<20x20xf32> ++ %7 = stablehlo.multiply %2, %3 : tensor<20x20xf32> + %8 = stablehlo.subtract %7, %4 : tensor<20x20xf32> + %9 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> + %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir +@@ -16,7 +16,7 @@ + %11 = stablehlo.constant dense<1> : tensor + %12 = stablehlo.subtract %3, %11 : tensor + %13 = stablehlo.select %10, %12, %3 : tensor, tensor +- %14 = stablehlo.multiply %2, %13 : tensor ++ %14 = stablehlo.multiply %13, %2 : tensor + %15 = stablehlo.subtract %1, %14 : tensor + %16 = stablehlo.constant dense<2> : tensor + %17 = stablehlo.add %15, %16 : tensor +@@ -32,7 +32,7 @@ + %27 = stablehlo.constant dense<1> : tensor + %28 = stablehlo.subtract %19, %27 : tensor + %29 = stablehlo.select %26, %28, %19 : tensor, tensor +- %30 = stablehlo.multiply %18, %29 : tensor ++ %30 = stablehlo.multiply %29, %18 : tensor + %31 = stablehlo.subtract %17, %30 : tensor + %32 = stablehlo.constant dense<-1> : tensor + %33 = stablehlo.multiply %arg0, %32 : tensor +@@ -48,7 +48,7 @@ + %43 = stablehlo.constant dense<1> : tensor + %44 = stablehlo.subtract %35, %43 : tensor + %45 = stablehlo.select %42, %44, %35 : tensor, tensor +- %46 = stablehlo.multiply %34, %45 : tensor ++ %46 = stablehlo.multiply %45, %34 : tensor + %47 = stablehlo.subtract %33, %46 : tensor + %48 = stablehlo.constant dense<-1> : tensor + %49 = stablehlo.multiply %arg0, %48 : tensor +@@ -64,7 +64,7 @@ + %59 = stablehlo.constant dense<1> : tensor + %60 = stablehlo.subtract %51, %59 : tensor + %61 = stablehlo.select %58, %60, %51 : tensor, tensor +- %62 = stablehlo.multiply %50, %61 : tensor ++ %62 = stablehlo.multiply %61, %50 : tensor + %63 = stablehlo.subtract %49, %62 : tensor + %64 = stablehlo.constant dense<2> : tensor + %65 = stablehlo.add %63, %64 : tensor +@@ -80,7 +80,7 @@ + %75 = stablehlo.constant dense<1> : tensor + %76 = stablehlo.subtract %67, %75 : tensor + %77 = stablehlo.select %74, %76, %67 : tensor, tensor +- %78 = stablehlo.multiply %66, %77 : tensor ++ %78 = stablehlo.multiply %77, %66 : tensor + %79 = stablehlo.subtract %65, %78 : tensor + %80 = stablehlo.constant dense<-1> : tensor + %81 = stablehlo.multiply %77, %80 : tensor +diff --ruN 
a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir +@@ -16,7 +16,7 @@ + %11 = stablehlo.constant dense<1> : tensor + %12 = stablehlo.subtract %3, %11 : tensor + %13 = stablehlo.select %10, %12, %3 : tensor, tensor +- %14 = stablehlo.multiply %2, %13 : tensor ++ %14 = stablehlo.multiply %13, %2 : tensor + %15 = stablehlo.subtract %1, %14 : tensor + %16 = stablehlo.constant dense<2> : tensor + %17 = stablehlo.add %15, %16 : tensor +@@ -32,7 +32,7 @@ + %27 = stablehlo.constant dense<1> : tensor + %28 = stablehlo.subtract %19, %27 : tensor + %29 = stablehlo.select %26, %28, %19 : tensor, tensor +- %30 = stablehlo.multiply %18, %29 : tensor ++ %30 = stablehlo.multiply %29, %18 : tensor + %31 = stablehlo.subtract %17, %30 : tensor + %32 = stablehlo.constant dense<-1> : tensor + %33 = stablehlo.multiply %arg0, %32 : tensor +@@ -48,7 +48,7 @@ + %43 = stablehlo.constant dense<1> : tensor + %44 = stablehlo.subtract %35, %43 : tensor + %45 = stablehlo.select %42, %44, %35 : tensor, tensor +- %46 = stablehlo.multiply %34, %45 : tensor ++ %46 = stablehlo.multiply %45, %34 : tensor + %47 = stablehlo.subtract %33, %46 : tensor + %48 = stablehlo.constant dense<-1> : tensor + %49 = stablehlo.multiply %arg0, %48 : tensor +@@ -64,7 +64,7 @@ + %59 = stablehlo.constant dense<1> : tensor + %60 = stablehlo.subtract %51, %59 : tensor + %61 = stablehlo.select %58, %60, %51 : tensor, tensor +- %62 = stablehlo.multiply %50, %61 : tensor ++ %62 = stablehlo.multiply %61, %50 : tensor + %63 = stablehlo.subtract %49, %62 : tensor + %64 = stablehlo.constant dense<2> : tensor + %65 = stablehlo.add %63, %64 : tensor +@@ -80,7 +80,7 @@ + %75 = stablehlo.constant dense<1> : tensor + %76 = stablehlo.subtract %67, %75 : tensor + %77 = stablehlo.select %74, %76, %67 : tensor, tensor +- %78 = stablehlo.multiply %66, %77 : tensor ++ %78 = stablehlo.multiply %77, %66 : tensor + %79 = stablehlo.subtract %65, %78 : tensor + %80 = stablehlo.constant dense<-1> : tensor + %81 = stablehlo.multiply %77, %80 : tensor +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir +@@ -21,7 +21,7 @@ + %15 = stablehlo.divide %11, %14 : tensor<20x20xf32> + %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> + %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> +- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> ++ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> + %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %21 = stablehlo.add %8, %20 : tensor<20x20xf32> +@@ -79,11 +79,11 @@ + %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> + %74 = stablehlo.add %66, %73 : tensor<20x20xf32> + %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> ++ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> + %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> + %79 = stablehlo.log_plus_one %78 : 
tensor<20x20xf32> +- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> ++ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> + %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> + %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> +@@ -95,10 +95,10 @@ + %89 = stablehlo.abs %88 : tensor<20x20xf32> + %90 = stablehlo.add %2, %89 : tensor<20x20xf32> + %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> ++ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> + %93 = stablehlo.cosine %92 : tensor<20x20xf32> + %94 = stablehlo.sine %92 : tensor<20x20xf32> +- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> ++ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> + %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> + %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> + %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir +@@ -21,7 +21,7 @@ + %15 = stablehlo.divide %11, %14 : tensor<20x20xf32> + %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> + %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> +- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> ++ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> + %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %21 = stablehlo.add %8, %20 : tensor<20x20xf32> +@@ -79,11 +79,11 @@ + %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> + %74 = stablehlo.add %66, %73 : tensor<20x20xf32> + %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> ++ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> + %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> + %79 = stablehlo.log_plus_one %78 : tensor<20x20xf32> +- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> ++ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> + %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> + %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> +@@ -95,10 +95,10 @@ + %89 = stablehlo.abs %88 : tensor<20x20xf32> + %90 = stablehlo.add %2, %89 : tensor<20x20xf32> + %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> ++ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> + %93 = stablehlo.cosine %92 : tensor<20x20xf32> + %94 = stablehlo.sine %92 : tensor<20x20xf32> +- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> ++ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> + %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> + %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> + %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir +@@ -20,7 +20,7 @@ + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.subtract %8, %14 : 
tensor<20x20xf32> + %16 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %17 = stablehlo.add %9, %16 : tensor<20x20xf32> ++ %17 = stablehlo.add %16, %9 : tensor<20x20xf32> + %18 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %19 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %20 = stablehlo.add %7, %19 : tensor<20x20xf32> +@@ -78,11 +78,11 @@ + %72 = stablehlo.divide %66, %68 : tensor<20x20xf32> + %73 = stablehlo.add %65, %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %75 = stablehlo.add %74, %7 : tensor<20x20xf32> ++ %75 = stablehlo.add %7, %74 : tensor<20x20xf32> + %76 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %77 = stablehlo.divide %7, %74 : tensor<20x20xf32> + %78 = stablehlo.log_plus_one %77 : tensor<20x20xf32> +- %79 = stablehlo.add %76, %78 : tensor<20x20xf32> ++ %79 = stablehlo.add %78, %76 : tensor<20x20xf32> + %80 = stablehlo.divide %71, %73 : tensor<20x20xf32> + %81 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> + %82 = stablehlo.divide %81, %75 : tensor<20x20xf32> +@@ -94,10 +94,10 @@ + %88 = stablehlo.abs %87 : tensor<20x20xf32> + %89 = stablehlo.add %0, %88 : tensor<20x20xf32> + %90 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %91 = stablehlo.multiply %90, %89 : tensor<20x20xf32> ++ %91 = stablehlo.multiply %89, %90 : tensor<20x20xf32> + %92 = stablehlo.cosine %91 : tensor<20x20xf32> + %93 = stablehlo.sine %91 : tensor<20x20xf32> +- %94 = stablehlo.multiply %90, %92 : tensor<20x20xf32> ++ %94 = stablehlo.multiply %92, %90 : tensor<20x20xf32> + %95 = stablehlo.divide %94, %93 : tensor<20x20xf32> + %96 = stablehlo.subtract %84, %95 : tensor<20x20xf32> + %97 = stablehlo.select %3, %96, %84 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> + %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> + %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> ++ %8 = stablehlo.multiply %6, %7 : tensor<20x20xf32> + %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> + %10 = stablehlo.add %8, %9 : tensor<20x20xf32> + %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> +@@ -33,7 +33,7 @@ + %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %28 = stablehlo.add %26, %27 : tensor<20x20xf32> + %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> ++ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> + %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %32 = stablehlo.add %30, %31 : tensor<20x20xf32> + %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir +@@ -11,7 +11,7 @@ + %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> + %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> + %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> ++ %8 = stablehlo.multiply 
%6, %7 : tensor<20x20xf32> + %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> + %10 = stablehlo.add %8, %9 : tensor<20x20xf32> + %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> +@@ -33,7 +33,7 @@ + %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %28 = stablehlo.add %26, %27 : tensor<20x20xf32> + %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> ++ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> + %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %32 = stablehlo.add %30, %31 : tensor<20x20xf32> + %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir +@@ -10,7 +10,7 @@ + %4 = stablehlo.clamp %2, %0, %3 : tensor<20x20xf32> + %5 = stablehlo.multiply %4, %4 : tensor<20x20xf32> + %6 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %7 = stablehlo.multiply %6, %5 : tensor<20x20xf32> ++ %7 = stablehlo.multiply %5, %6 : tensor<20x20xf32> + %8 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> + %9 = stablehlo.add %7, %8 : tensor<20x20xf32> + %10 = stablehlo.multiply %9, %5 : tensor<20x20xf32> +@@ -32,7 +32,7 @@ + %26 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> + %27 = stablehlo.add %25, %26 : tensor<20x20xf32> + %28 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %29 = stablehlo.multiply %28, %5 : tensor<20x20xf32> ++ %29 = stablehlo.multiply %5, %28 : tensor<20x20xf32> + %30 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> + %31 = stablehlo.add %29, %30 : tensor<20x20xf32> + %32 = stablehlo.multiply %31, %5 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> ++ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %16 = stablehlo.add %14, %15 : tensor<20x20xf32> + %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> +@@ -45,7 +45,7 @@ + %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %40 = stablehlo.add %38, %39 : tensor<20x20xf32> + %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> ++ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> + %43 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %44 = stablehlo.add %42, %43 : tensor<20x20xf32> + %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> +@@ -81,7 +81,7 @@ + %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> + %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> ++ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> + %79 = stablehlo.constant 
dense<7.85386146E-5> : tensor<20x20xf32> + %80 = stablehlo.add %78, %79 : tensor<20x20xf32> + %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> ++ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %16 = stablehlo.add %14, %15 : tensor<20x20xf32> + %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> +@@ -45,7 +45,7 @@ + %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %40 = stablehlo.add %38, %39 : tensor<20x20xf32> + %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> ++ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> + %43 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %44 = stablehlo.add %42, %43 : tensor<20x20xf32> + %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> +@@ -81,7 +81,7 @@ + %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> + %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> ++ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> + %79 = stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> + %80 = stablehlo.add %78, %79 : tensor<20x20xf32> + %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir +@@ -16,7 +16,7 @@ + %10 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.compare LT, %4, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %12 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %13 = stablehlo.multiply %12, %6 : tensor<20x20xf32> ++ %13 = stablehlo.multiply %6, %12 : tensor<20x20xf32> + %14 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> + %15 = stablehlo.add %13, %14 : tensor<20x20xf32> + %16 = stablehlo.multiply %15, %6 : tensor<20x20xf32> +@@ -44,7 +44,7 @@ + %38 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> + %39 = stablehlo.add %37, %38 : tensor<20x20xf32> + %40 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %41 = stablehlo.multiply %40, %6 : tensor<20x20xf32> ++ %41 = stablehlo.multiply %6, %40 : tensor<20x20xf32> + %42 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> + %43 = stablehlo.add %41, %42 : tensor<20x20xf32> + %44 = stablehlo.multiply %43, %6 : tensor<20x20xf32> +@@ -80,7 +80,7 @@ + %74 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %75 = stablehlo.multiply %0, %0 : tensor<20x20xf32> + %76 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> +- %77 = stablehlo.multiply %76, %75 : tensor<20x20xf32> ++ %77 = stablehlo.multiply %75, %76 : tensor<20x20xf32> + %78 = 
stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> + %79 = stablehlo.add %77, %78 : tensor<20x20xf32> + %80 = stablehlo.multiply %79, %75 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff 
--ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir 
b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir +--- 
stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir +@@ 
-34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir +--- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir ++++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir +@@ -34,7 +34,7 @@ + %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %13 = stablehlo.constant dense<1> : tensor + %14 = 
stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> +- %15 = stablehlo.add %6, %14 : tensor<1xi32> ++ %15 = stablehlo.add %14, %6 : tensor<1xi32> + %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> + %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -45,7 +45,7 @@ + %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %24 = stablehlo.constant dense<1> : tensor + %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> +- %26 = stablehlo.add %7, %25 : tensor<1xi32> ++ %26 = stablehlo.add %25, %7 : tensor<1xi32> + %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> + %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> +- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> 
+- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir +@@ -39,7 +39,7 @@ + %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %18 = stablehlo.constant dense<3> : tensor + %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> +- %20 = stablehlo.add %6, %19 : tensor<1xi32> ++ %20 = stablehlo.add %19, %6 : tensor<1xi32> + %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> + %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -55,7 +55,7 @@ + %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %34 = stablehlo.constant dense<3> : tensor + %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> +- %36 = stablehlo.add %7, %35 : tensor<1xi32> ++ %36 = stablehlo.add %35, %7 : tensor<1xi32> + %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add 
%21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : 
tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, 
dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> 
tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir +@@ -42,7 +42,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> +- %23 = stablehlo.add %8, %22 : tensor<1xi32> ++ %23 = stablehlo.add %22, %8 : tensor<1xi32> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> + %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> 
tensor<1x1xi32> + %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -58,7 +58,7 @@ + %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %37 = stablehlo.constant dense<3> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> +- %39 = stablehlo.add %9, %38 : tensor<1xi32> ++ %39 = stablehlo.add %38, %9 : tensor<1xi32> + %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> + %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) 
{dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = 
dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : 
(tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> 
tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 
= stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, 
SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED 
: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir +--- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir ++++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir +@@ -41,7 +41,7 @@ + %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %20 = stablehlo.constant dense<3> : tensor + %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> +- %22 = stablehlo.add %8, %21 : tensor<1xi32> ++ %22 = stablehlo.add %21, %8 : tensor<1xi32> + %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> + %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +@@ -57,7 +57,7 @@ + %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %36 = stablehlo.constant dense<3> : tensor + %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> +- %38 = stablehlo.add %9, %37 : tensor<1xi32> ++ %38 = stablehlo.add %37, %9 : tensor<1xi32> + %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> + %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> + %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : 
tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> + %241 = stablehlo.constant dense<-1.000000e+00> : tensor + %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> ++ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> + %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> + %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> + %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +--- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine 
%37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> + %241 = stablehlo.constant dense<-1.000000e+00> : tensor + %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> ++ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> + %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> + %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> + %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : 
tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> + %242 = stablehlo.constant dense<-1.000000e+00> : tensor + %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> ++ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> + %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> + %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> + %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -272,7 +272,7 @@ + %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> + %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> +- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> ++ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> + %38 = stablehlo.sine %37 : tensor<20x20xf32> + %39 = stablehlo.log %38 : tensor<20x20xf32> + %40 = stablehlo.is_finite %39 : 
(tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -290,17 +290,17 @@ + %52 = stablehlo.add %50, %51 : tensor<20x20xf32> + %53 = stablehlo.constant dense<7.500000e+00> : tensor + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> ++ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> + %56 = stablehlo.constant dense<2.01490307> : tensor + %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> + %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> +- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> + %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> + %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> + %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> +- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> ++ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> + %66 = stablehlo.constant dense<1.000000e+00> : tensor + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.constant dense<676.520386> : tensor +@@ -311,7 +311,7 @@ + %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %74 = stablehlo.add %72, %73 : tensor<20x20xf32> + %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> +- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> ++ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> + %77 = stablehlo.constant dense<-1259.13916> : tensor + %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %79 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -585,7 +585,7 @@ + %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> + %242 = stablehlo.constant dense<-1.000000e+00> : tensor + %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> ++ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> + %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> + %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> + %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +@@ -202,7 +202,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -270,7 +270,7 @@ + %32 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %33 = stablehlo.subtract %32, %27 : tensor<20x20xf32> + %34 = stablehlo.select %30, %33, %27 : tensor<20x20xi1>, tensor<20x20xf32> +- %35 = stablehlo.multiply %24, %34 : tensor<20x20xf32> ++ %35 = stablehlo.multiply %34, %24 : tensor<20x20xf32> + %36 = stablehlo.sine %35 : tensor<20x20xf32> + %37 = stablehlo.log %36 : tensor<20x20xf32> + %38 = stablehlo.is_finite %37 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -288,17 +288,17 @@ 
+ %50 = stablehlo.add %48, %49 : tensor<20x20xf32> + %51 = stablehlo.constant dense<7.500000e+00> : tensor + %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %53 = stablehlo.add %52, %48 : tensor<20x20xf32> ++ %53 = stablehlo.add %48, %52 : tensor<20x20xf32> + %54 = stablehlo.constant dense<2.01490307> : tensor + %55 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %56 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %57 = stablehlo.divide %48, %56 : tensor<20x20xf32> + %58 = stablehlo.log_plus_one %57 : tensor<20x20xf32> +- %59 = stablehlo.add %55, %58 : tensor<20x20xf32> ++ %59 = stablehlo.add %58, %55 : tensor<20x20xf32> + %60 = stablehlo.divide %53, %59 : tensor<20x20xf32> + %61 = stablehlo.subtract %50, %60 : tensor<20x20xf32> + %62 = stablehlo.multiply %61, %59 : tensor<20x20xf32> +- %63 = stablehlo.add %43, %62 : tensor<20x20xf32> ++ %63 = stablehlo.add %62, %43 : tensor<20x20xf32> + %64 = stablehlo.constant dense<1.000000e+00> : tensor + %65 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %66 = stablehlo.constant dense<676.520386> : tensor +@@ -309,7 +309,7 @@ + %71 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %72 = stablehlo.add %70, %71 : tensor<20x20xf32> + %73 = stablehlo.divide %67, %72 : tensor<20x20xf32> +- %74 = stablehlo.add %65, %73 : tensor<20x20xf32> ++ %74 = stablehlo.add %73, %65 : tensor<20x20xf32> + %75 = stablehlo.constant dense<-1259.13916> : tensor + %76 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %77 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -583,7 +583,7 @@ + %238 = stablehlo.multiply %iterArg_4, %237 : tensor<20x20xf32> + %239 = stablehlo.constant dense<-1.000000e+00> : tensor + %240 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %241 = stablehlo.multiply %240, %iterArg_1 : tensor<20x20xf32> ++ %241 = stablehlo.multiply %iterArg_1, %240 : tensor<20x20xf32> + %242 = stablehlo.multiply %241, %iterArg_3 : tensor<20x20xf32> + %243 = stablehlo.multiply %225, %225 : tensor<20x20xf32> + %244 = stablehlo.divide %242, %243 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : 
tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> + %229 = stablehlo.constant dense<-1.000000e+00> : tensor + %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> ++ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> + %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> + %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> + %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +--- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = 
stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> + %229 = stablehlo.constant dense<-1.000000e+00> : tensor + %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> ++ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> + %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> + %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> + %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + 
%50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> + %230 = stablehlo.constant dense<-1.000000e+00> : tensor + %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> ++ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> + %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> + %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> + %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -268,7 +268,7 @@ + %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> + %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> +- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> ++ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> + %34 = stablehlo.sine %33 : tensor<20x20xf32> + %35 = stablehlo.log %34 : tensor<20x20xf32> + %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -286,17 +286,17 @@ + %48 = stablehlo.add %46, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor + %50 = stablehlo.constant dense<7.500000e+00> : 
tensor<20x20xf32> +- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> ++ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor + %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> + %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> +- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> ++ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> + %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> + %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> + %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> +- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> ++ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> + %62 = stablehlo.constant dense<1.000000e+00> : tensor + %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %64 = stablehlo.constant dense<676.520386> : tensor +@@ -307,7 +307,7 @@ + %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %70 = stablehlo.add %68, %69 : tensor<20x20xf32> + %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> +- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> ++ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> + %73 = stablehlo.constant dense<-1259.13916> : tensor + %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %75 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -428,7 +428,7 @@ + %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> + %230 = stablehlo.constant dense<-1.000000e+00> : tensor + %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> ++ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> + %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> + %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> + %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +@@ -47,7 +47,7 @@ + %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> + %22 = stablehlo.constant dense<-1.000000e+00> : tensor + %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> ++ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> + %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> + %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> + %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> +@@ -266,7 +266,7 @@ + %28 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %29 = stablehlo.subtract %28, %23 : tensor<20x20xf32> + %30 = stablehlo.select %26, %29, %23 : tensor<20x20xi1>, tensor<20x20xf32> +- %31 = stablehlo.multiply %20, %30 : tensor<20x20xf32> ++ %31 = stablehlo.multiply %30, %20 : tensor<20x20xf32> + %32 = stablehlo.sine %31 : tensor<20x20xf32> + %33 = stablehlo.log %32 : tensor<20x20xf32> + %34 = stablehlo.is_finite %33 : (tensor<20x20xf32>) -> tensor<20x20xi1> +@@ -284,17 +284,17 @@ + %46 = stablehlo.add %44, %45 : tensor<20x20xf32> + %47 = stablehlo.constant dense<7.500000e+00> : tensor + %48 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %49 = stablehlo.add %48, %44 : 
tensor<20x20xf32> ++ %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<2.01490307> : tensor + %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> + %53 = stablehlo.divide %44, %52 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %51, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %51 : tensor<20x20xf32> + %56 = stablehlo.divide %49, %55 : tensor<20x20xf32> + %57 = stablehlo.subtract %46, %56 : tensor<20x20xf32> + %58 = stablehlo.multiply %57, %55 : tensor<20x20xf32> +- %59 = stablehlo.add %39, %58 : tensor<20x20xf32> ++ %59 = stablehlo.add %58, %39 : tensor<20x20xf32> + %60 = stablehlo.constant dense<1.000000e+00> : tensor + %61 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %62 = stablehlo.constant dense<676.520386> : tensor +@@ -305,7 +305,7 @@ + %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %68 = stablehlo.add %66, %67 : tensor<20x20xf32> + %69 = stablehlo.divide %63, %68 : tensor<20x20xf32> +- %70 = stablehlo.add %61, %69 : tensor<20x20xf32> ++ %70 = stablehlo.add %69, %61 : tensor<20x20xf32> + %71 = stablehlo.constant dense<-1259.13916> : tensor + %72 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %73 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -426,7 +426,7 @@ + %226 = stablehlo.multiply %iterArg_4, %225 : tensor<20x20xf32> + %227 = stablehlo.constant dense<-1.000000e+00> : tensor + %228 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> +- %229 = stablehlo.multiply %228, %iterArg_1 : tensor<20x20xf32> ++ %229 = stablehlo.multiply %iterArg_1, %228 : tensor<20x20xf32> + %230 = stablehlo.multiply %229, %iterArg_3 : tensor<20x20xf32> + %231 = stablehlo.multiply %213, %213 : tensor<20x20xf32> + %232 = stablehlo.divide %230, %231 : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir +--- stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir ++++ stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir ++++ 
stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.add %8, %11 : tensor<20x20xf32> + %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> ++ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> + %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %17 = stablehlo.add %8, %16 : tensor<20x20xf32> +@@ -54,18 +54,18 @@ + %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> + %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> ++ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %53 = stablehlo.divide %8, %50 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> + %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> + %57 = stablehlo.add %8, %3 : tensor<20x20xf32> + %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> + %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> + %60 = stablehlo.log %49 : tensor<20x20xf32> + %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> ++ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> + %63 = stablehlo.add %62, %60 : tensor<20x20xf32> + %64 = stablehlo.abs %2 : tensor<20x20xf32> + %65 = stablehlo.floor %64 : tensor<20x20xf32> +@@ -74,7 +74,7 @@ + %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> + %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> + %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> ++ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> + %72 = stablehlo.sine %71 : tensor<20x20xf32> + %73 = stablehlo.log %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir +@@ -17,7 +17,7 @@ + %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %12 = stablehlo.add %8, %11 : tensor<20x20xf32> + %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> +- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> ++ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> + %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %17 = stablehlo.add %8, %16 : tensor<20x20xf32> +@@ -54,18 +54,18 @@ + %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> + %49 = stablehlo.add %44, %48 : tensor<20x20xf32> + %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> ++ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> + %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %53 = stablehlo.divide %8, %50 : tensor<20x20xf32> + %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> +- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> ++ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> + %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> + %57 = 
stablehlo.add %8, %3 : tensor<20x20xf32> + %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> + %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> + %60 = stablehlo.log %49 : tensor<20x20xf32> + %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> ++ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> + %63 = stablehlo.add %62, %60 : tensor<20x20xf32> + %64 = stablehlo.abs %2 : tensor<20x20xf32> + %65 = stablehlo.floor %64 : tensor<20x20xf32> +@@ -74,7 +74,7 @@ + %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> + %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> + %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> ++ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> + %72 = stablehlo.sine %71 : tensor<20x20xf32> + %73 = stablehlo.log %72 : tensor<20x20xf32> + %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir +@@ -16,7 +16,7 @@ + %10 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> + %11 = stablehlo.add %7, %10 : tensor<20x20xf32> + %12 = stablehlo.divide %9, %11 : tensor<20x20xf32> +- %13 = stablehlo.add %8, %12 : tensor<20x20xf32> ++ %13 = stablehlo.add %12, %8 : tensor<20x20xf32> + %14 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> + %15 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> + %16 = stablehlo.add %7, %15 : tensor<20x20xf32> +@@ -53,18 +53,18 @@ + %47 = stablehlo.divide %44, %46 : tensor<20x20xf32> + %48 = stablehlo.add %43, %47 : tensor<20x20xf32> + %49 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> +- %50 = stablehlo.add %49, %7 : tensor<20x20xf32> ++ %50 = stablehlo.add %7, %49 : tensor<20x20xf32> + %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> + %52 = stablehlo.divide %7, %49 : tensor<20x20xf32> + %53 = stablehlo.log_plus_one %52 : tensor<20x20xf32> +- %54 = stablehlo.add %51, %53 : tensor<20x20xf32> ++ %54 = stablehlo.add %53, %51 : tensor<20x20xf32> + %55 = stablehlo.divide %50, %54 : tensor<20x20xf32> + %56 = stablehlo.add %7, %2 : tensor<20x20xf32> + %57 = stablehlo.subtract %56, %55 : tensor<20x20xf32> + %58 = stablehlo.multiply %57, %54 : tensor<20x20xf32> + %59 = stablehlo.log %48 : tensor<20x20xf32> + %60 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> +- %61 = stablehlo.add %60, %58 : tensor<20x20xf32> ++ %61 = stablehlo.add %58, %60 : tensor<20x20xf32> + %62 = stablehlo.add %61, %59 : tensor<20x20xf32> + %63 = stablehlo.abs %0 : tensor<20x20xf32> + %64 = stablehlo.floor %63 : tensor<20x20xf32> +@@ -73,7 +73,7 @@ + %67 = stablehlo.subtract %5, %65 : tensor<20x20xf32> + %68 = stablehlo.select %66, %67, %65 : tensor<20x20xi1>, tensor<20x20xf32> + %69 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> +- %70 = stablehlo.multiply %69, %68 : tensor<20x20xf32> ++ %70 = stablehlo.multiply %68, %69 : tensor<20x20xf32> + %71 = stablehlo.sine %70 : tensor<20x20xf32> + %72 = stablehlo.log %71 : tensor<20x20xf32> + %73 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir b/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir +--- 
stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir ++++ stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir +@@ -72,12 +72,12 @@ + %24 = stablehlo.subtract %14, %23 : tensor + %25 = stablehlo.minimum %18, %24 : tensor + %26 = stablehlo.constant dense<0.000000e+00> : tensor +- %27 = stablehlo.maximum %26, %25 : tensor ++ %27 = stablehlo.maximum %25, %26 : tensor + %28 = stablehlo.constant dense<1.000000e+00> : tensor + %29 = stablehlo.subtract %14, %28 : tensor + %30 = stablehlo.minimum %19, %29 : tensor + %31 = stablehlo.constant dense<0.000000e+00> : tensor +- %32 = stablehlo.maximum %31, %30 : tensor ++ %32 = stablehlo.maximum %30, %31 : tensor + %33 = stablehlo.convert %27 : (tensor) -> tensor + %34 = stablehlo.convert %32 : (tensor) -> tensor + %35 = stablehlo.constant dense<5> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir +@@ -338,7 +338,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -346,7 +346,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -656,13 +656,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -773,7 +773,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir +@@ -341,7 +341,7 @@ + %122 
= stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -349,7 +349,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -659,13 +659,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -776,7 +776,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir +@@ -338,7 +338,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -346,7 +346,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -656,13 +656,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = 
func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -773,7 +773,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir +--- stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir ++++ stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir +@@ -341,7 +341,7 @@ + %122 = stablehlo.add %120, %121 : tensor + %123 = stablehlo.reshape %122 : (tensor) -> tensor + %124 = stablehlo.constant dense<0.000000e+00> : tensor +- %125 = stablehlo.maximum %124, %123 : tensor ++ %125 = stablehlo.maximum %123, %124 : tensor + %126 = stablehlo.constant dense<0.000000e+00> : tensor + %127 = stablehlo.constant dense<1.000000e+00> : tensor + %128 = stablehlo.constant dense<2.000000e+00> : tensor +@@ -349,7 +349,7 @@ + cond { + %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor + %152 = stablehlo.constant dense<3.310000e-02> : tensor +- %153 = stablehlo.multiply %152, %151 : tensor ++ %153 = stablehlo.multiply %151, %152 : tensor + %154 = stablehlo.constant dense<1.000000e+00> : tensor + %155 = stablehlo.subtract %154, %153 : tensor + %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor +@@ -659,13 +659,13 @@ + %289 = stablehlo.add %287, %288 : tensor + %290 = stablehlo.reshape %289 : (tensor) -> tensor + %291 = stablehlo.constant dense<-0.99999994> : tensor +- %292 = stablehlo.maximum %291, %290 : tensor ++ %292 = stablehlo.maximum %290, %291 : tensor + %293 = func.call @erf_inv(%292) : (tensor) -> tensor + %294 = stablehlo.constant dense<1.41421354> : tensor +- %295 = stablehlo.multiply %294, %293 : tensor ++ %295 = stablehlo.multiply %293, %294 : tensor + %296 = stablehlo.multiply %295, %iterArg_9 : tensor + %297 = stablehlo.constant dense<1.000000e+00> : tensor +- %298 = stablehlo.add %297, %296 : tensor ++ %298 = stablehlo.add %296, %297 : tensor + stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor + } + %181 = stablehlo.multiply %180#2, %180#2 : tensor +@@ -776,7 +776,7 @@ + %222 = stablehlo.add %220, %221 : tensor + %223 = stablehlo.reshape %222 : (tensor) -> tensor + %224 = stablehlo.constant dense<0.000000e+00> : tensor +- %225 = stablehlo.maximum %224, %223 : tensor ++ %225 = stablehlo.maximum %223, %224 : tensor + stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor + } + %130 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir 
b/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir +@@ -125,7 +125,7 @@ + %55 = stablehlo.add %53, %54 : tensor + %56 = stablehlo.reshape %55 : (tensor) -> tensor + %57 = stablehlo.constant dense<0.000000e+00> : tensor +- %58 = stablehlo.maximum %57, %56 : tensor ++ %58 = stablehlo.maximum %56, %57 : tensor + %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor + return %59 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir +@@ -125,7 +125,7 @@ + %55 = stablehlo.add %53, %54 : tensor + %56 = stablehlo.reshape %55 : (tensor) -> tensor + %57 = stablehlo.constant dense<0.000000e+00> : tensor +- %58 = stablehlo.maximum %57, %56 : tensor ++ %58 = stablehlo.maximum %56, %57 : tensor + %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor + return %59 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir +--- stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir ++++ stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir +@@ -110,7 +110,7 @@ + %40 = stablehlo.add %38, %39 : tensor + %41 = stablehlo.reshape %40 : (tensor) -> tensor + %42 = stablehlo.constant dense<0.000000e+00> : tensor +- %43 = stablehlo.maximum %42, %41 : tensor ++ %43 = stablehlo.maximum %41, %42 : tensor + %44 = stablehlo.custom_call @check.eq(%43, %1) : (tensor, tensor) -> tensor + return %44 : tensor + } +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, %43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -225,9 +225,9 @@ + %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> + %503 = stablehlo.constant dense<2.000000e+00> : tensor + 
%504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> + %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> +- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> + %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> + %509 = stablehlo.constant dense<1.000000e+00> : tensor + %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -237,10 +237,10 @@ + %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> + %515 = stablehlo.multiply %496, %514 : tensor<9xf32> + %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> +- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> + %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> + %519 = stablehlo.subtract %518, %510 : tensor<9xf32> +- %520 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> + %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> + %522 = stablehlo.multiply %519, %521 : tensor<9xf32> + %523 = stablehlo.divide %516, %522 : tensor<9xf32> +@@ -322,7 +322,7 @@ + %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %80 = stablehlo.subtract %79, %74 : tensor<9xf32> + %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> +- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> ++ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> + %83 = stablehlo.sine %82 : tensor<9xf32> + %84 = stablehlo.log %83 : tensor<9xf32> + %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> +@@ -340,17 +340,17 @@ + %97 = stablehlo.add %95, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<7.500000e+00> : tensor + %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %100 = stablehlo.add %99, %95 : tensor<9xf32> ++ %100 = stablehlo.add %95, %99 : tensor<9xf32> + %101 = stablehlo.constant dense<2.01490307> : tensor + %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %104 = stablehlo.divide %95, %103 : tensor<9xf32> + %105 = stablehlo.log_plus_one %104 : tensor<9xf32> +- %106 = stablehlo.add %102, %105 : tensor<9xf32> ++ %106 = stablehlo.add %105, %102 : tensor<9xf32> + %107 = stablehlo.divide %100, %106 : tensor<9xf32> + %108 = stablehlo.subtract %97, %107 : tensor<9xf32> + %109 = stablehlo.multiply %108, %106 : tensor<9xf32> +- %110 = stablehlo.add %90, %109 : tensor<9xf32> ++ %110 = stablehlo.add %109, %90 : tensor<9xf32> + %111 = stablehlo.constant dense<1.000000e+00> : tensor + %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %113 = stablehlo.constant dense<676.520386> : tensor +@@ -361,7 +361,7 @@ + %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %119 = stablehlo.add %117, %118 : tensor<9xf32> + %120 = stablehlo.divide %114, %119 : tensor<9xf32> +- %121 = stablehlo.add %112, %120 : tensor<9xf32> ++ %121 = stablehlo.add %120, %112 : tensor<9xf32> + %122 = stablehlo.constant dense<-1259.13916> : tensor + %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %124 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -453,7 +453,7 @@ + %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %211 = stablehlo.subtract %210, %205 : tensor<9xf32> + %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> +- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> ++ %213 = 
stablehlo.multiply %212, %202 : tensor<9xf32> + %214 = stablehlo.sine %213 : tensor<9xf32> + %215 = stablehlo.log %214 : tensor<9xf32> + %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> +@@ -471,17 +471,17 @@ + %228 = stablehlo.add %226, %227 : tensor<9xf32> + %229 = stablehlo.constant dense<7.500000e+00> : tensor + %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %231 = stablehlo.add %230, %226 : tensor<9xf32> ++ %231 = stablehlo.add %226, %230 : tensor<9xf32> + %232 = stablehlo.constant dense<2.01490307> : tensor + %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %235 = stablehlo.divide %226, %234 : tensor<9xf32> + %236 = stablehlo.log_plus_one %235 : tensor<9xf32> +- %237 = stablehlo.add %233, %236 : tensor<9xf32> ++ %237 = stablehlo.add %236, %233 : tensor<9xf32> + %238 = stablehlo.divide %231, %237 : tensor<9xf32> + %239 = stablehlo.subtract %228, %238 : tensor<9xf32> + %240 = stablehlo.multiply %239, %237 : tensor<9xf32> +- %241 = stablehlo.add %221, %240 : tensor<9xf32> ++ %241 = stablehlo.add %240, %221 : tensor<9xf32> + %242 = stablehlo.constant dense<1.000000e+00> : tensor + %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %244 = stablehlo.constant dense<676.520386> : tensor +@@ -492,7 +492,7 @@ + %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %250 = stablehlo.add %248, %249 : tensor<9xf32> + %251 = stablehlo.divide %245, %250 : tensor<9xf32> +- %252 = stablehlo.add %243, %251 : tensor<9xf32> ++ %252 = stablehlo.add %251, %243 : tensor<9xf32> + %253 = stablehlo.constant dense<-1259.13916> : tensor + %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %255 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -586,7 +586,7 @@ + %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %344 = stablehlo.subtract %343, %338 : tensor<9xf32> + %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> +- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> ++ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> + %347 = stablehlo.sine %346 : tensor<9xf32> + %348 = stablehlo.log %347 : tensor<9xf32> + %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> +@@ -604,17 +604,17 @@ + %361 = stablehlo.add %359, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<7.500000e+00> : tensor + %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %364 = stablehlo.add %363, %359 : tensor<9xf32> ++ %364 = stablehlo.add %359, %363 : tensor<9xf32> + %365 = stablehlo.constant dense<2.01490307> : tensor + %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %368 = stablehlo.divide %359, %367 : tensor<9xf32> + %369 = stablehlo.log_plus_one %368 : tensor<9xf32> +- %370 = stablehlo.add %366, %369 : tensor<9xf32> ++ %370 = stablehlo.add %369, %366 : tensor<9xf32> + %371 = stablehlo.divide %364, %370 : tensor<9xf32> + %372 = stablehlo.subtract %361, %371 : tensor<9xf32> + %373 = stablehlo.multiply %372, %370 : tensor<9xf32> +- %374 = stablehlo.add %354, %373 : tensor<9xf32> ++ %374 = stablehlo.add %373, %354 : tensor<9xf32> + %375 = stablehlo.constant dense<1.000000e+00> : tensor + %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %377 = stablehlo.constant dense<676.520386> : tensor +@@ -625,7 +625,7 @@ + %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %383 = stablehlo.add %381, %382 : tensor<9xf32> + %384 = 
stablehlo.divide %378, %383 : tensor<9xf32> +- %385 = stablehlo.add %376, %384 : tensor<9xf32> ++ %385 = stablehlo.add %384, %376 : tensor<9xf32> + %386 = stablehlo.constant dense<-1259.13916> : tensor + %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %388 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, %43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -225,9 +225,9 @@ + %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> + %503 = stablehlo.constant dense<2.000000e+00> : tensor + %504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> + %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> +- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> + %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> + %509 = stablehlo.constant dense<1.000000e+00> : tensor + %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -237,10 +237,10 @@ + %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> + %515 = stablehlo.multiply %496, %514 : tensor<9xf32> + %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> +- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> + %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> + %519 = stablehlo.subtract %518, %510 : tensor<9xf32> +- %520 = stablehlo.multiply %504, %496 : tensor<9xf32> ++ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> + %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> + %522 = stablehlo.multiply %519, %521 : tensor<9xf32> + %523 = stablehlo.divide %516, %522 : tensor<9xf32> +@@ -322,7 +322,7 @@ + %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %80 = stablehlo.subtract %79, %74 : tensor<9xf32> + %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> +- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> ++ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> 
+ %83 = stablehlo.sine %82 : tensor<9xf32> + %84 = stablehlo.log %83 : tensor<9xf32> + %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> +@@ -340,17 +340,17 @@ + %97 = stablehlo.add %95, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<7.500000e+00> : tensor + %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %100 = stablehlo.add %99, %95 : tensor<9xf32> ++ %100 = stablehlo.add %95, %99 : tensor<9xf32> + %101 = stablehlo.constant dense<2.01490307> : tensor + %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %104 = stablehlo.divide %95, %103 : tensor<9xf32> + %105 = stablehlo.log_plus_one %104 : tensor<9xf32> +- %106 = stablehlo.add %102, %105 : tensor<9xf32> ++ %106 = stablehlo.add %105, %102 : tensor<9xf32> + %107 = stablehlo.divide %100, %106 : tensor<9xf32> + %108 = stablehlo.subtract %97, %107 : tensor<9xf32> + %109 = stablehlo.multiply %108, %106 : tensor<9xf32> +- %110 = stablehlo.add %90, %109 : tensor<9xf32> ++ %110 = stablehlo.add %109, %90 : tensor<9xf32> + %111 = stablehlo.constant dense<1.000000e+00> : tensor + %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %113 = stablehlo.constant dense<676.520386> : tensor +@@ -361,7 +361,7 @@ + %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %119 = stablehlo.add %117, %118 : tensor<9xf32> + %120 = stablehlo.divide %114, %119 : tensor<9xf32> +- %121 = stablehlo.add %112, %120 : tensor<9xf32> ++ %121 = stablehlo.add %120, %112 : tensor<9xf32> + %122 = stablehlo.constant dense<-1259.13916> : tensor + %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %124 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -453,7 +453,7 @@ + %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %211 = stablehlo.subtract %210, %205 : tensor<9xf32> + %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> +- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> ++ %213 = stablehlo.multiply %212, %202 : tensor<9xf32> + %214 = stablehlo.sine %213 : tensor<9xf32> + %215 = stablehlo.log %214 : tensor<9xf32> + %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> +@@ -471,17 +471,17 @@ + %228 = stablehlo.add %226, %227 : tensor<9xf32> + %229 = stablehlo.constant dense<7.500000e+00> : tensor + %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %231 = stablehlo.add %230, %226 : tensor<9xf32> ++ %231 = stablehlo.add %226, %230 : tensor<9xf32> + %232 = stablehlo.constant dense<2.01490307> : tensor + %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %235 = stablehlo.divide %226, %234 : tensor<9xf32> + %236 = stablehlo.log_plus_one %235 : tensor<9xf32> +- %237 = stablehlo.add %233, %236 : tensor<9xf32> ++ %237 = stablehlo.add %236, %233 : tensor<9xf32> + %238 = stablehlo.divide %231, %237 : tensor<9xf32> + %239 = stablehlo.subtract %228, %238 : tensor<9xf32> + %240 = stablehlo.multiply %239, %237 : tensor<9xf32> +- %241 = stablehlo.add %221, %240 : tensor<9xf32> ++ %241 = stablehlo.add %240, %221 : tensor<9xf32> + %242 = stablehlo.constant dense<1.000000e+00> : tensor + %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %244 = stablehlo.constant dense<676.520386> : tensor +@@ -492,7 +492,7 @@ + %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %250 = stablehlo.add %248, %249 : tensor<9xf32> + %251 = stablehlo.divide %245, %250 : tensor<9xf32> +- %252 = stablehlo.add 
%243, %251 : tensor<9xf32> ++ %252 = stablehlo.add %251, %243 : tensor<9xf32> + %253 = stablehlo.constant dense<-1259.13916> : tensor + %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %255 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -586,7 +586,7 @@ + %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %344 = stablehlo.subtract %343, %338 : tensor<9xf32> + %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> +- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> ++ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> + %347 = stablehlo.sine %346 : tensor<9xf32> + %348 = stablehlo.log %347 : tensor<9xf32> + %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> +@@ -604,17 +604,17 @@ + %361 = stablehlo.add %359, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<7.500000e+00> : tensor + %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %364 = stablehlo.add %363, %359 : tensor<9xf32> ++ %364 = stablehlo.add %359, %363 : tensor<9xf32> + %365 = stablehlo.constant dense<2.01490307> : tensor + %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %368 = stablehlo.divide %359, %367 : tensor<9xf32> + %369 = stablehlo.log_plus_one %368 : tensor<9xf32> +- %370 = stablehlo.add %366, %369 : tensor<9xf32> ++ %370 = stablehlo.add %369, %366 : tensor<9xf32> + %371 = stablehlo.divide %364, %370 : tensor<9xf32> + %372 = stablehlo.subtract %361, %371 : tensor<9xf32> + %373 = stablehlo.multiply %372, %370 : tensor<9xf32> +- %374 = stablehlo.add %354, %373 : tensor<9xf32> ++ %374 = stablehlo.add %373, %354 : tensor<9xf32> + %375 = stablehlo.constant dense<1.000000e+00> : tensor + %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %377 = stablehlo.constant dense<676.520386> : tensor +@@ -625,7 +625,7 @@ + %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %383 = stablehlo.add %381, %382 : tensor<9xf32> + %384 = stablehlo.divide %378, %383 : tensor<9xf32> +- %385 = stablehlo.add %376, %384 : tensor<9xf32> ++ %385 = stablehlo.add %384, %376 : tensor<9xf32> + %386 = stablehlo.constant dense<-1259.13916> : tensor + %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %388 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir +--- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir ++++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir +@@ -71,9 +71,9 @@ + %40 = stablehlo.multiply %38, %39 : tensor<9xf32> + %41 = stablehlo.constant dense<2.000000e+00> : tensor + %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> + %44 = stablehlo.add %25, %43 : tensor<9xf32> +- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> + %46 = stablehlo.add %25, %45 : tensor<9xf32> + %47 = stablehlo.constant dense<1.000000e+00> : tensor + %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -83,10 +83,10 @@ + %52 = stablehlo.subtract %35, %32 : tensor<9xf32> + %53 = stablehlo.multiply %32, %52 : tensor<9xf32> + %54 = stablehlo.multiply %53, %39 : tensor<9xf32> +- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> + %56 = 
stablehlo.add %25, %55 : tensor<9xf32> + %57 = stablehlo.subtract %56, %48 : tensor<9xf32> +- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> ++ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> + %59 = stablehlo.add %25, %58 : tensor<9xf32> + %60 = stablehlo.multiply %57, %59 : tensor<9xf32> + %61 = stablehlo.divide %54, %60 : tensor<9xf32> +@@ -222,9 +222,9 @@ + %498 = stablehlo.multiply %497, %iterArg_6 : tensor<9xf32> + %499 = stablehlo.constant dense<2.000000e+00> : tensor + %500 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> +- %501 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %501 = stablehlo.multiply %492, %500 : tensor<9xf32> + %502 = stablehlo.add %iterArg_4, %501 : tensor<9xf32> +- %503 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %503 = stablehlo.multiply %492, %500 : tensor<9xf32> + %504 = stablehlo.add %iterArg_4, %503 : tensor<9xf32> + %505 = stablehlo.constant dense<1.000000e+00> : tensor + %506 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> +@@ -234,10 +234,10 @@ + %510 = stablehlo.subtract %iterArg_5, %492 : tensor<9xf32> + %511 = stablehlo.multiply %492, %510 : tensor<9xf32> + %512 = stablehlo.multiply %511, %iterArg_6 : tensor<9xf32> +- %513 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %513 = stablehlo.multiply %492, %500 : tensor<9xf32> + %514 = stablehlo.add %iterArg_4, %513 : tensor<9xf32> + %515 = stablehlo.subtract %514, %506 : tensor<9xf32> +- %516 = stablehlo.multiply %500, %492 : tensor<9xf32> ++ %516 = stablehlo.multiply %492, %500 : tensor<9xf32> + %517 = stablehlo.add %iterArg_4, %516 : tensor<9xf32> + %518 = stablehlo.multiply %515, %517 : tensor<9xf32> + %519 = stablehlo.divide %512, %518 : tensor<9xf32> +@@ -319,7 +319,7 @@ + %76 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %77 = stablehlo.subtract %76, %71 : tensor<9xf32> + %78 = stablehlo.select %74, %77, %71 : tensor<9xi1>, tensor<9xf32> +- %79 = stablehlo.multiply %68, %78 : tensor<9xf32> ++ %79 = stablehlo.multiply %78, %68 : tensor<9xf32> + %80 = stablehlo.sine %79 : tensor<9xf32> + %81 = stablehlo.log %80 : tensor<9xf32> + %82 = stablehlo.is_finite %81 : (tensor<9xf32>) -> tensor<9xi1> +@@ -337,17 +337,17 @@ + %94 = stablehlo.add %92, %93 : tensor<9xf32> + %95 = stablehlo.constant dense<7.500000e+00> : tensor + %96 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %97 = stablehlo.add %96, %92 : tensor<9xf32> ++ %97 = stablehlo.add %92, %96 : tensor<9xf32> + %98 = stablehlo.constant dense<2.01490307> : tensor + %99 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %100 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %101 = stablehlo.divide %92, %100 : tensor<9xf32> + %102 = stablehlo.log_plus_one %101 : tensor<9xf32> +- %103 = stablehlo.add %99, %102 : tensor<9xf32> ++ %103 = stablehlo.add %102, %99 : tensor<9xf32> + %104 = stablehlo.divide %97, %103 : tensor<9xf32> + %105 = stablehlo.subtract %94, %104 : tensor<9xf32> + %106 = stablehlo.multiply %105, %103 : tensor<9xf32> +- %107 = stablehlo.add %87, %106 : tensor<9xf32> ++ %107 = stablehlo.add %106, %87 : tensor<9xf32> + %108 = stablehlo.constant dense<1.000000e+00> : tensor + %109 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %110 = stablehlo.constant dense<676.520386> : tensor +@@ -358,7 +358,7 @@ + %115 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %116 = stablehlo.add %114, %115 : tensor<9xf32> + %117 = stablehlo.divide %111, %116 : tensor<9xf32> +- %118 = stablehlo.add %109, %117 : tensor<9xf32> ++ %118 = stablehlo.add %117, %109 : 
tensor<9xf32> + %119 = stablehlo.constant dense<-1259.13916> : tensor + %120 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %121 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -450,7 +450,7 @@ + %207 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %208 = stablehlo.subtract %207, %202 : tensor<9xf32> + %209 = stablehlo.select %205, %208, %202 : tensor<9xi1>, tensor<9xf32> +- %210 = stablehlo.multiply %199, %209 : tensor<9xf32> ++ %210 = stablehlo.multiply %209, %199 : tensor<9xf32> + %211 = stablehlo.sine %210 : tensor<9xf32> + %212 = stablehlo.log %211 : tensor<9xf32> + %213 = stablehlo.is_finite %212 : (tensor<9xf32>) -> tensor<9xi1> +@@ -468,17 +468,17 @@ + %225 = stablehlo.add %223, %224 : tensor<9xf32> + %226 = stablehlo.constant dense<7.500000e+00> : tensor + %227 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %228 = stablehlo.add %227, %223 : tensor<9xf32> ++ %228 = stablehlo.add %223, %227 : tensor<9xf32> + %229 = stablehlo.constant dense<2.01490307> : tensor + %230 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %231 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %232 = stablehlo.divide %223, %231 : tensor<9xf32> + %233 = stablehlo.log_plus_one %232 : tensor<9xf32> +- %234 = stablehlo.add %230, %233 : tensor<9xf32> ++ %234 = stablehlo.add %233, %230 : tensor<9xf32> + %235 = stablehlo.divide %228, %234 : tensor<9xf32> + %236 = stablehlo.subtract %225, %235 : tensor<9xf32> + %237 = stablehlo.multiply %236, %234 : tensor<9xf32> +- %238 = stablehlo.add %218, %237 : tensor<9xf32> ++ %238 = stablehlo.add %237, %218 : tensor<9xf32> + %239 = stablehlo.constant dense<1.000000e+00> : tensor + %240 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %241 = stablehlo.constant dense<676.520386> : tensor +@@ -489,7 +489,7 @@ + %246 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %247 = stablehlo.add %245, %246 : tensor<9xf32> + %248 = stablehlo.divide %242, %247 : tensor<9xf32> +- %249 = stablehlo.add %240, %248 : tensor<9xf32> ++ %249 = stablehlo.add %248, %240 : tensor<9xf32> + %250 = stablehlo.constant dense<-1259.13916> : tensor + %251 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %252 = stablehlo.constant dense<1.000000e+00> : tensor +@@ -583,7 +583,7 @@ + %340 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %341 = stablehlo.subtract %340, %335 : tensor<9xf32> + %342 = stablehlo.select %338, %341, %335 : tensor<9xi1>, tensor<9xf32> +- %343 = stablehlo.multiply %332, %342 : tensor<9xf32> ++ %343 = stablehlo.multiply %342, %332 : tensor<9xf32> + %344 = stablehlo.sine %343 : tensor<9xf32> + %345 = stablehlo.log %344 : tensor<9xf32> + %346 = stablehlo.is_finite %345 : (tensor<9xf32>) -> tensor<9xi1> +@@ -601,17 +601,17 @@ + %358 = stablehlo.add %356, %357 : tensor<9xf32> + %359 = stablehlo.constant dense<7.500000e+00> : tensor + %360 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> +- %361 = stablehlo.add %360, %356 : tensor<9xf32> ++ %361 = stablehlo.add %356, %360 : tensor<9xf32> + %362 = stablehlo.constant dense<2.01490307> : tensor + %363 = stablehlo.constant dense<2.01490307> : tensor<9xf32> + %364 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> + %365 = stablehlo.divide %356, %364 : tensor<9xf32> + %366 = stablehlo.log_plus_one %365 : tensor<9xf32> +- %367 = stablehlo.add %363, %366 : tensor<9xf32> ++ %367 = stablehlo.add %366, %363 : tensor<9xf32> + %368 = stablehlo.divide %361, %367 : tensor<9xf32> + %369 = stablehlo.subtract %358, %368 : tensor<9xf32> + %370 = 
stablehlo.multiply %369, %367 : tensor<9xf32> +- %371 = stablehlo.add %351, %370 : tensor<9xf32> ++ %371 = stablehlo.add %370, %351 : tensor<9xf32> + %372 = stablehlo.constant dense<1.000000e+00> : tensor + %373 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %374 = stablehlo.constant dense<676.520386> : tensor +@@ -622,7 +622,7 @@ + %379 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> + %380 = stablehlo.add %378, %379 : tensor<9xf32> + %381 = stablehlo.divide %375, %380 : tensor<9xf32> +- %382 = stablehlo.add %373, %381 : tensor<9xf32> ++ %382 = stablehlo.add %381, %373 : tensor<9xf32> + %383 = stablehlo.constant dense<-1259.13916> : tensor + %384 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> + %385 = stablehlo.constant dense<1.000000e+00> : tensor +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir +--- stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir +@@ -19,7 +19,7 @@ + %13 = stablehlo.add %10, %11 : tensor<20x20xf32> + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.add %10, %14 : tensor<20x20xf32> +- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> ++ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> + %17 = stablehlo.abs %2 : tensor<20x20xf32> + %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir +--- stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir +@@ -19,7 +19,7 @@ + %13 = stablehlo.add %10, %11 : tensor<20x20xf32> + %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> + %15 = stablehlo.add %10, %14 : tensor<20x20xf32> +- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> ++ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> + %17 = stablehlo.abs %2 : tensor<20x20xf32> + %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir +--- stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir ++++ stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir +@@ -18,7 +18,7 @@ + %12 = stablehlo.add %9, %10 : tensor<20x20xf32> + %13 = stablehlo.divide %9, %12 : tensor<20x20xf32> + %14 = stablehlo.add %9, %13 : tensor<20x20xf32> +- %15 = stablehlo.multiply %11, %14 : tensor<20x20xf32> ++ %15 = stablehlo.multiply %14, %11 : tensor<20x20xf32> + %16 = stablehlo.abs %0 : tensor<20x20xf32> + %17 = stablehlo.compare LT, %16, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> + %18 = stablehlo.select %17, %15, %8 : tensor<20x20xi1>, tensor<20x20xf32> +diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor { + %0 = stablehlo.constant 
dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.constant dense<0> : tensor<1xi32> + %3 = stablehlo.constant dense<0> : tensor<1xi32> + %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir +--- stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir ++++ stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir +@@ -3,7 +3,7 @@ + module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<1x4xf32> { + %0 = stablehlo.constant dense<-1> : tensor +- %1 = stablehlo.add %0, %arg0 : tensor ++ %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.convert %1 : (tensor) -> tensor + %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> + %4 = stablehlo.constant dense<0> : tensor<1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir +@@ -29,7 +29,7 @@ + %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %21 = stablehlo.constant dense<3> : tensor + %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi64> +- %23 = stablehlo.add %8, %22 : tensor<1xi64> ++ %23 = stablehlo.add %22, %8 : tensor<1xi64> + %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi64> + %25 = stablehlo.convert %24 : (tensor<1xi64>) -> tensor<1xi32> + %26 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +@@ -46,7 +46,7 @@ + %37 = stablehlo.compare LT, %9, %36, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %38 = stablehlo.constant dense<3> : tensor + %39 = stablehlo.broadcast_in_dim %38, dims = [] : (tensor) -> tensor<1xi64> +- %40 = stablehlo.add %9, %39 : tensor<1xi64> ++ %40 = stablehlo.add %39, %9 : tensor<1xi64> + %41 = stablehlo.select %37, %40, %9 : tensor<1xi1>, tensor<1xi64> + %42 = stablehlo.convert %41 : (tensor<1xi64>) -> tensor<1xi32> + %43 = stablehlo.broadcast_in_dim %42, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir +--- stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir +@@ -34,7 +34,7 @@ + %25 = stablehlo.compare LT, %16, %24, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %26 = stablehlo.constant dense<2> : tensor + %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<1xi64> +- %28 = stablehlo.add %16, %27 : tensor<1xi64> ++ %28 = stablehlo.add %27, %16 : tensor<1xi64> + %29 = stablehlo.select %25, %28, %16 : tensor<1xi1>, tensor<1xi64> + %30 = stablehlo.convert %29 : (tensor<1xi64>) -> tensor<1xi32> + %31 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +@@ -49,7 +49,7 @@ + %40 = stablehlo.compare LT, %17, %39, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> + %41 = stablehlo.constant dense<2> : tensor + %42 = stablehlo.broadcast_in_dim %41, dims = [] : (tensor) -> tensor<1xi64> +- %43 = stablehlo.add %17, %42 : tensor<1xi64> ++ %43 = stablehlo.add %42, %17 : tensor<1xi64> + 
%44 = stablehlo.select %40, %43, %17 : tensor<1xi1>, tensor<1xi64> + %45 = stablehlo.convert %44 : (tensor<1xi64>) -> tensor<1xi32> + %46 = stablehlo.broadcast_in_dim %45, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir +--- stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir ++++ stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir +@@ -47,7 +47,7 @@ + %38 = stablehlo.compare LT, %29, %37, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %39 = stablehlo.constant dense<2> : tensor + %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<2xi64> +- %41 = stablehlo.add %29, %40 : tensor<2xi64> ++ %41 = stablehlo.add %40, %29 : tensor<2xi64> + %42 = stablehlo.select %38, %41, %29 : tensor<2xi1>, tensor<2xi64> + %43 = stablehlo.convert %42 : (tensor<2xi64>) -> tensor<2xi32> + %44 = stablehlo.broadcast_in_dim %43, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -62,7 +62,7 @@ + %53 = stablehlo.compare LT, %30, %52, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %54 = stablehlo.constant dense<2> : tensor + %55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor) -> tensor<2xi64> +- %56 = stablehlo.add %30, %55 : tensor<2xi64> ++ %56 = stablehlo.add %55, %30 : tensor<2xi64> + %57 = stablehlo.select %53, %56, %30 : tensor<2xi1>, tensor<2xi64> + %58 = stablehlo.convert %57 : (tensor<2xi64>) -> tensor<2xi32> + %59 = stablehlo.broadcast_in_dim %58, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir ++++ stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir +@@ -41,7 +41,7 @@ + %32 = stablehlo.compare LT, %22, %31, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %33 = stablehlo.constant dense<2> : tensor + %34 = stablehlo.broadcast_in_dim %33, dims = [] : (tensor) -> tensor<2xi64> +- %35 = stablehlo.add %22, %34 : tensor<2xi64> ++ %35 = stablehlo.add %34, %22 : tensor<2xi64> + %36 = stablehlo.select %32, %35, %22 : tensor<2xi1>, tensor<2xi64> + %37 = stablehlo.convert %36 : (tensor<2xi64>) -> tensor<2xi32> + %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -56,7 +56,7 @@ + %47 = stablehlo.compare LT, %23, %46, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %48 = stablehlo.constant dense<2> : tensor + %49 = stablehlo.broadcast_in_dim %48, dims = [] : (tensor) -> tensor<2xi64> +- %50 = stablehlo.add %23, %49 : tensor<2xi64> ++ %50 = stablehlo.add %49, %23 : tensor<2xi64> + %51 = stablehlo.select %47, %50, %23 : tensor<2xi1>, tensor<2xi64> + %52 = stablehlo.convert %51 : (tensor<2xi64>) -> tensor<2xi32> + %53 = stablehlo.broadcast_in_dim %52, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir ++++ 
stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir +@@ -45,7 +45,7 @@ + %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %37 = stablehlo.constant dense<4> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> +- %39 = stablehlo.add %22, %38 : tensor<2xi64> ++ %39 = stablehlo.add %38, %22 : tensor<2xi64> + %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> + %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> + %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -64,7 +64,7 @@ + %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %56 = stablehlo.constant dense<4> : tensor + %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> +- %58 = stablehlo.add %23, %57 : tensor<2xi64> ++ %58 = stablehlo.add %57, %23 : tensor<2xi64> + %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> + %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> + %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir +--- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir ++++ stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir +@@ -45,7 +45,7 @@ + %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %37 = stablehlo.constant dense<4> : tensor + %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> +- %39 = stablehlo.add %22, %38 : tensor<2xi64> ++ %39 = stablehlo.add %38, %22 : tensor<2xi64> + %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> + %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> + %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +@@ -64,7 +64,7 @@ + %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> + %56 = stablehlo.constant dense<4> : tensor + %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> +- %58 = stablehlo.add %23, %57 : tensor<2xi64> ++ %58 = stablehlo.add %57, %23 : tensor<2xi64> + %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> + %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> + %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +--- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ++++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +@@ -426,6 +426,172 @@ + + // ----- + ++// CHECK-LABEL: func @dynamic_reduce_window_success_static_result_type ++func.func @dynamic_reduce_window_success_static_result_type(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<2x2xf32> { ++ // CHECK-NOT: stablehlo.dynamic_reduce_window ++ // CHECK: "stablehlo.reduce_window"(%arg0, %arg1) ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): ++ // CHECK-NEXT: %[[VAL1:.*]] = 
stablehlo.add %arg2, %arg3 : tensor ++ // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor ++ // CHECK-NEXT: }) { ++ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, ++ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, ++ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> ++ // CHECK-SAME: } : (tensor<3x2xf32>, tensor) -> tensor<2x2xf32> ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %5 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_success_dynamic_result_type ++func.func @dynamic_reduce_window_success_dynamic_result_type(%arg0: tensor, %arg1: tensor) -> tensor { ++ // CHECK-NOT: stablehlo.dynamic_reduce_window ++ // CHECK: "stablehlo.reduce_window"(%arg0, %arg1) ({ ++ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): ++ // CHECK-NEXT: %[[VAL1:.*]] = stablehlo.add %arg2, %arg3 : tensor ++ // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor ++ // CHECK-NEXT: }) { ++ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, ++ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, ++ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, ++ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> ++ // CHECK-SAME: } : (tensor, tensor) -> tensor ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor ++ func.return %5 : tensor ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// TODO(burmako): Implement tests for verification failures for dynamic_reduce_window. 
++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_dimensions ++func.func @dynamic_reduce_window_inapplicable_dynamic_window_dimensions(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window ++ %0 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %arg2, %0, %1, %2, %3) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %4 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_strides ++func.func @dynamic_reduce_window_inapplicable_dynamic_window_strides(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %arg2, %1, %2, %3) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %4 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_base_dilations ++func.func @dynamic_reduce_window_inapplicable_dynamic_base_dilations(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %arg2, %2, %3) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %4 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_dilations ++func.func @dynamic_reduce_window_inapplicable_dynamic_window_dilations(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %3 = 
stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %arg2, %3) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %4 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_padding ++func.func @dynamic_reduce_window_inapplicable_dynamic_padding(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2x2xi64>) -> tensor<2x2xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %arg2) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> ++ func.return %4 : tensor<2x2xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ + // CHECK-LABEL: func @dynamic_reshape_success + func.func @dynamic_reshape_success(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { + // CHECK-NOT: stablehlo.dynamic_reshape +@@ -452,6 +618,44 @@ + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x?xf32> + return %1 : tensor<1x?xf32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_rng_bit_generator_success ++func.func @dynamic_rng_bit_generator_success(%arg0: tensor<2xui64>) -> tensor<1x4xf32> { ++ // CHECK-NOT: stablehlo.dynamic_rng_bit_generator ++ // CHECK: stablehlo.rng_bit_generator %arg0, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<1x4xf32>) ++ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { ++ rng_algorithm = #stablehlo ++ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor<1x4xf32>) ++ return %1#1 : tensor<1x4xf32> ++} ++ ++// TODO(burmako): Implement tests for verification failures for dynamic_rng_bit_generator. 
++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_rng_bit_generator_inapplicable_dynamic_output_shape ++func.func @dynamic_rng_bit_generator_inapplicable_dynamic_output_shape(%arg0: tensor<2xui64>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> { ++ // CHECK: stablehlo.dynamic_rng_bit_generator ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %arg1) { ++ rng_algorithm = #stablehlo ++ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor<1x4xf32>) ++ return %1#1 : tensor<1x4xf32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_rng_bit_generator_inapplicable_dynamic_output_type ++func.func @dynamic_rng_bit_generator_inapplicable_dynamic_output_type(%arg0: tensor<2xui64>) -> tensor { ++ // CHECK: stablehlo.dynamic_rng_bit_generator ++ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { ++ rng_algorithm = #stablehlo ++ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor) ++ return %1#1 : tensor + } + + // ----- +diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +--- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ++++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +@@ -607,12 +607,45 @@ + + // ----- + ++// CHECK-LABEL: @main ++func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<*xf32> { ++ // CHECK: stablehlo.dynamic_reduce_window{{.*}} -> tensor<2x2xf32> ++ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> ++ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> ++ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> ++ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> ++ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { ++ called_computations = [@dynamic_reduce_window0] ++ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> ++ func.return %5 : tensor<*xf32> ++} ++ ++func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { ++ %0 = stablehlo.add %arg0, %arg1 : tensor ++ func.return %0 : tensor ++} ++ ++// ----- ++ + // CHECK-LABEL: @refine_dynamic_reshape + func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> { + // CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32> + %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @refine_dynamic_rng_bit_generator ++func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor<*xf32>) { ++ // CHECK: stablehlo.dynamic_rng_bit_generator{{.*}} -> (tensor<2xui64>, tensor<1x4xf32>) ++ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { ++ rng_algorithm = #stablehlo ++ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor, tensor<*xf32>) ++ func.return %1#0, %1#1 : tensor, tensor<*xf32> + } + + // ----- +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ++++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +@@ -24,6 +24,7 @@ + #include 
"mlir/Interfaces/InferTypeOpInterface.h" + #include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/ExperimentalOps.h" + #include "stablehlo/dialect/StablehloOps.h" + #include "stablehlo/transforms/Passes.h" + +@@ -198,6 +199,54 @@ + } + }; + ++struct CanonicalizeDynamicReduceWindowOpPattern ++ : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicReduceWindowOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicReduceWindowOpAdaptor op = *maybeOp; ++ ++ // ReduceWindowOp supports dynamic shapes for operands and results, so we ++ // don't check for that here unlike in some other patterns in this pass. ++ SmallVector windowDimensions, windowStrides, baseDilations, ++ windowDilations, padding; ++ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) ++ return rewriter.notifyMatchFailure(op, ++ "expected static window_dimensions"); ++ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) ++ return rewriter.notifyMatchFailure(op, "expected static window_strides"); ++ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) ++ return rewriter.notifyMatchFailure(op, "expected static base_dilations"); ++ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected static window_dilations"); ++ if (failed(hlo::matchInts(op.getPadding(), padding))) ++ return rewriter.notifyMatchFailure(op, "expected static padding"); ++ auto newOp = rewriter.create( ++ op->getLoc(), op->getResultTypes(), op.getInputs(), op.getInitValues(), ++ rewriter.getI64TensorAttr(windowDimensions), ++ rewriter.getI64TensorAttr(windowStrides), ++ rewriter.getI64TensorAttr(baseDilations), ++ rewriter.getI64TensorAttr(windowDilations), ++ hlo::getPaddingAttr(&rewriter, padding)); ++ ++ // Inline the called computation into newOp. ++ // This is somewhat annoying because we also have to rewrite the original ++ // func::ReturnOp into stablehlo::ReturnOp. ++ rewriter.cloneRegionBefore(op.getBody(), newOp.getBody(), ++ newOp.getBody().end()); ++ auto funcReturnOp = ++ cast(newOp.getBody().front().getTerminator()); ++ rewriter.setInsertionPointToEnd(&newOp.getBody().front()); ++ rewriter.replaceOpWithNewOp( ++ funcReturnOp, funcReturnOp.getOperands()); ++ rewriter.replaceOp(op, newOp->getResults()); ++ return success(); ++ } ++}; ++ + struct CanonicalizeDynamicReshapeOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; +@@ -210,6 +259,27 @@ + if (!op.getType().hasStaticShape()) + return rewriter.notifyMatchFailure(op, "expected static result type"); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); ++ return success(); ++ } ++}; ++ ++struct CanonicalizeDynamicRngBitGeneratorOpPattern ++ : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicRngBitGeneratorOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; ++ ++ // This pattern ignores and discards the output_shape operand. We rely on ++ // the verifier to make sure that its value is consistent with result type. 
++ if (!succeeded(hlo::matchInts(op.getOutputShape()))) ++ return rewriter.notifyMatchFailure(op, "expected static output_shape"); ++ if (!op.getOutput().getType().cast().hasStaticShape()) ++ return rewriter.notifyMatchFailure(op, "expected static output type"); ++ rewriter.replaceOpWithNewOp( ++ op, op->getResultTypes(), op.getRngAlgorithm(), op.getInitialState()); + return success(); + } + }; +@@ -320,7 +390,9 @@ + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add( + &getContext()); + patterns.add(&getContext()); +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -43,6 +43,7 @@ + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" + #include "stablehlo/dialect/Base.h" + #include "stablehlo/dialect/ChloOps.h" ++#include "stablehlo/dialect/ExperimentalOps.h" + #include "stablehlo/dialect/StablehloOps.h" + #include "stablehlo/dialect/TypeInference.h" + #include "stablehlo/transforms/Passes.h" +@@ -844,12 +845,78 @@ + } + }; + ++struct RefineDynamicReduceWindowOpPattern ++ : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicReduceWindowOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicReduceWindowOpAdaptor op = *maybeOp; ++ ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. 
++ SmallVector windowDimensions, windowStrides, baseDilations, ++ windowDilations, padding; ++ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dimensions"); ++ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_strides"); ++ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant base_dilations"); ++ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) ++ return rewriter.notifyMatchFailure(op, ++ "expected constant window_dilations"); ++ if (failed(hlo::matchInts(op.getPadding(), padding))) ++ return rewriter.notifyMatchFailure(op, "expected constant padding"); ++ ++ SmallVector inferredReturnTypes; ++ if (failed(hlo::inferReduceWindowOp( ++ /*location=*/{}, op.getInputs(), op.getInitValues(), ++ rewriter.getI64TensorAttr(windowDimensions), ++ rewriter.getI64TensorAttr(windowStrides), ++ rewriter.getI64TensorAttr(baseDilations), ++ rewriter.getI64TensorAttr(windowDilations), ++ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) ++ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); ++ return refineReturnTypes(rewriter, op, inferredReturnTypes); ++ } ++}; ++ + struct RefineDynamicReshapeOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + return refineReturnShape(rewriter, op, op.getOutputShape()); ++ } ++}; ++ ++struct RefineDynamicRngBitGeneratorOpPattern ++ : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicRngBitGeneratorOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; ++ ++ // At the moment, we only support refining return types using fully static ++ // shape values which serves the current use cases well. ++ // Support for partially static shape values is left for future work. ++ auto initialStateType = op.getInitialState().getType().cast(); ++ SmallVector outputShape; ++ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) ++ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); ++ ++ // We only need to refine the shape of `output` (the second result). ++ // The shape of `output_state` (the first result) is determined by the shape ++ // of `initial_state`, so we ignore it and provide an empty refinement. 
++ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); + } + }; + +@@ -1181,7 +1248,9 @@ + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + From 4741d078d3669a1d8fe8d7b36d4a00305015f420 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Thu, 27 Jul 2023 17:42:32 -0700 Subject: [PATCH 285/410] [XLA] Turn copy(elementwise(copy(x),broadcast(y))) into elementwise(x,copy(broadcast(y))) PiperOrigin-RevId: 551689944 --- .../xla/service/algebraic_simplifier.cc | 36 +++++++++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 25 +++++++++++++ 2 files changed, 61 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 386ad197f9ad7d..a39c17fb4a90f8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1550,6 +1550,42 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { } } } + + // Convert copy(elementwise(copy(x), broadcast(y))) -> + // elementwise(x,copy(broadcast(y)) + if (copy->operand(0)->IsElementwiseBinary() && + // Compare needs direction too, hence skipping. + copy->operand(0)->opcode() != HloOpcode::kCompare) { + auto ew_binary = copy->mutable_operand(0); + auto computation = copy->parent(); + HloInstruction* inner_op = nullptr; + HloInstruction* broadcast_input = nullptr; + HloInstruction* broadcast = nullptr; + HloInstruction* inner_copy = nullptr; + if ((Match(ew_binary->mutable_operand(0), + m::Copy(&inner_copy, m::Op(&inner_op))) && + Match(ew_binary->mutable_operand(1), + m::Broadcast(&broadcast, m::Op(&broadcast_input)))) || + (Match(ew_binary->mutable_operand(1), + m::Copy(&inner_copy, m::Op(&inner_op))) && + Match(ew_binary->mutable_operand(0), + m::Broadcast(&broadcast, m::Op(&broadcast_input))))) { + if (Shape::Equal().IgnoreMemorySpaceInLayout()( + copy->shape(), inner_copy->operand(0)->shape())) { + HloInstruction* new_broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + broadcast->shape(), broadcast_input, + broadcast->dimensions())); + HloInstruction* broadcast_copy = + computation->AddInstruction(HloInstruction::CreateUnary( + inner_op->shape(), HloOpcode::kCopy, new_broadcast)); + HloInstruction* new_ew_binary = computation->AddInstruction( + HloInstruction::CreateBinary(copy->shape(), ew_binary->opcode(), + inner_op, broadcast_copy)); + return ReplaceInstruction(copy, new_ew_binary); + } + } + } } return OkStatus(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 18eea15b0b1066..d7c86cf26fa06d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1195,6 +1195,31 @@ TEST_F(AlgebraicSimplifierTest, ArrayOvershootTest) { ASSERT_FALSE(simplifier.Run(m.get()).value()); } +TEST_F(AlgebraicSimplifierTest, CopyEwCopyBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + param0 = f32[32,32,8,32]{3,1,2,0} parameter(0) + param1 = f32[8,32]{1,0} parameter(1) + brd = f32[32,32,8,32]{3,2,1,0} broadcast(param1), dimensions={2,3} + cpy = f32[32,32,8,32]{3,2,1,0} copy(param0) + add = f32[32,32,8,32]{3,2,1,0} add(cpy, brd) + 
ROOT cpy2 = f32[32,32,8,32]{3,1,2,0} copy(add) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + SCOPED_TRACE(m->ToString()); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Add(m::Parameter(0), + m::Copy(m::Broadcast(m::Parameter(1)))))); + SCOPED_TRACE(m->ToString()); +} + // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto m = CreateNewVerifiedModule(); From b644c21e8423845a7925ebe1b94488e1a89ed79d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 18:12:58 -0700 Subject: [PATCH 286/410] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/1f5ac07bdba7f5803ac94c466a698914aa5ad7d0. PiperOrigin-RevId: 551695850 --- third_party/tf_runtime/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index f5a5e9f61cab3c..73be79fe34728f 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "81aca58669e2db2c53168db65ecdf40add3dcd2d" - TFRT_SHA256 = "f9738cfe7e65b03dabdc5ba61da39a93bdf3719de48193f303e3bf475da813a4" + TFRT_COMMIT = "1f5ac07bdba7f5803ac94c466a698914aa5ad7d0" + TFRT_SHA256 = "589ef73620007a7c129ddfca0f34089073bef5fd9bf148d63be85f841e29e9b4" tf_http_archive( name = "tf_runtime", From 7677b1b561ac372525ddceb92ab246858f2cc52e Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Fri, 28 Jul 2023 01:37:18 +0000 Subject: [PATCH 287/410] Target Monterey (12.0) as the minimum compatible os for arm64 wheels --- .../tools/ci_build/osx/arm64/.macos.bazelrc | 40 +------------------ 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc index 375ff18da06083..3ddc77319db16b 100644 --- a/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc +++ b/tensorflow/tools/ci_build/osx/arm64/.macos.bazelrc @@ -17,8 +17,8 @@ build --define=tensorflow_mkldnn_contraction_kernel=0 # Settings for MacOS on ARM CPUs. build --cpu=darwin_arm64 -build --macos_minimum_os=11.0 -build --action_env MACOSX_DEPLOYMENT_TARGET=11.0 +build --macos_minimum_os=12.0 +build --action_env MACOSX_DEPLOYMENT_TARGET=12.0 # Test-related settings below this point. test --verbose_failures=true --local_test_jobs=HOST_CPUS --test_output=errors @@ -42,39 +42,3 @@ test:nonpip_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss test:nonpip_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:nonpip_filters --test_lang_filters=cc,py test:nonpip --config=nonpip_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xla/service/gpu/... -//tensorflow/compiler/xla/tools/multihost_hlo_runner/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
-//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test - -# "pip tests" run a similar suite of tests the "nonpip" tests, but do something -# odd to attempt to validate the quality of the pip package. The wheel is -# installed into a virtual environment, and then that venv is used to run all -# bazel tests with a special flag "--define=no_tensorflow_py_deps=true", which -# drops all the bazel dependencies for each py_test; this makes all the tests -# use the wheel's TensorFlow installation instead of the one made available -# through bazel. This must be done in a different root directory, //bazel_pip/..., -# because "import tensorflow" run from the root directory would instead import -# the folder instead of the venv package. -# -# Pass --config=pip to run the same suite of tests. If you want to run just one -# test for investigation, you'll need --config=pip_base instead, and then you -# can specify whichever target you want. -test:pip_base --define=no_tensorflow_py_deps=true -test:pip_filters --build_tag_filters=-nopip,-no_pip,-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py39,-no_oss_py310,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:pip_filters --test_tag_filters=-nopip,-no_pip,-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py39,-no_oss_py310,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:pip_filters --test_lang_filters=py -test:pip --config=pip_base --config=pip_filters -- //bazel_pip/tensorflow/python/... - -# For building libtensorflow archives -build:libtensorflow_filters --action_env TF_NEED_HDFS=0 -build:libtensorflow_filters --action_env TF_NEED_ROCM=0 --action_env TF_NEED_MKL=0 -build:libtensorflow_filters --action_env COMPUTECPP_PATH="/usr/local" -test:libtensorflow_test --config=libtensorflow_filters -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test -build:libtensorflow_build --config=libtensorflow_filters -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip - -# For continuous builds -# nodistinct_host_configuration saves building twice a lot of targets -test:continuous_filters --nodistinct_host_configuration --keep_going -test:continuous_filters --build_tests_only --test_output=errors --flaky_test_attempts=3 -test:continuous_filters --test_size_filters=small,medium --test_timeout=300,450,1200,3600 -test:continuous_filters --test_tag_filters=-no_oss,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:continuous_filters --build_tag_filters=-no_oss,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 - -test:continuous --config=continuous_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/compiler/xrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/python/integration_testing/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
-//tensorflow/compiler/xla/tests:local_client_aot_test_computation -//tensorflow/compiler/xla/tests:local_client_aot_test_helper -//tensorflow/compiler/xla/tests:local_client_aot_test \ No newline at end of file From b0d2cb4b6bd14351605c0d4a0633b23053ad561c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 18:57:28 -0700 Subject: [PATCH 288/410] Add MLIR op definition for VirtualInfeedEnqueue op and VirtualInfeedDequeue op PiperOrigin-RevId: 551703953 --- .../mlir/tensorflow/tests/side-effect-analysis-test.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index 6cd1cf67c0a93d..77f585a1aab70d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2876,4 +2876,4 @@ func.func @tpu_execute_effect( func.return // expected-remark@above {{ID: 6}} // expected-remark@above {{Sinks: {5}}} -} +} \ No newline at end of file From b8cc608717f6a4849d3f726f95b083150732b3c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 19:13:39 -0700 Subject: [PATCH 289/410] Add a command line flag to provide a path to a file containing compilation environment proto for TPU. PiperOrigin-RevId: 551706484 --- tensorflow/compiler/xla/tools/run_hlo_module.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.h b/tensorflow/compiler/xla/tools/run_hlo_module.h index 0a30eb8affd2ec..d1e4e5d7ed9296 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.h +++ b/tensorflow/compiler/xla/tools/run_hlo_module.h @@ -48,6 +48,8 @@ struct RunHloModuleOptions { std::string input_format; std::string input_module; bool use_buffer_assignment_from_proto{false}; + // The format and the usage of the option is platform-dependent. + std::string input_compilation_environments; int iterations{1}; std::string output_literals_file; std::string input_literals_file; From bd083e3d86a8d5a8d9dd2eebedf25933469d3866 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 19:18:24 -0700 Subject: [PATCH 290/410] Prevent redundant and excessive HLO dumps from the auto-sharding pass. PiperOrigin-RevId: 551707268 --- .../xla/hlo/experimental/auto_sharding/auto_sharding.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc index 7024ae30066361..6d51ba0e1764b6 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3989,10 +3989,6 @@ StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Canonicalize layouts based on LayoutCanonicalizationCallback. ----- TF_RETURN_IF_ERROR(CanonicalizeLayouts(module)); - XLA_VLOG_LINES(7, absl::StrCat("After auto sharding for mesh ", - spmd::ToString(option_.device_mesh_shape), - ":\n", module->ToString())); - DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); return module_is_changed ? 
AutoShardingResult::kModuleChangedShardingPerformed : AutoShardingResult::kModuleUnchanged; From 0f9124670e30613a4ef2e1ad7264b57741cd6381 Mon Sep 17 00:00:00 2001 From: Songyi Han Date: Thu, 27 Jul 2023 19:41:48 -0700 Subject: [PATCH 291/410] Support preserving function alias in QuantizePtqDynamicRange This change adds function alias preserving support for weight-only and dynamic range quantization. PiperOrigin-RevId: 551710786 --- .../integration_test/quantize_model_test.py | 84 +++++++++++++++++-- .../python/pywrap_quantize_model.cc | 5 +- .../tensorflow/python/quantize_model.cc | 29 +++++-- .../tensorflow/python/quantize_model.h | 3 +- .../tensorflow/python/quantize_model.py | 7 ++ 5 files changed, 114 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index a023c8e91483e5..5a455ac81baaa0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -2376,8 +2376,13 @@ def data_gen() -> repr_dataset.RepresentativeDataset: if alias == 'conv_func': for func in meta_graph_def.graph_def.library.function: if func.signature.name == func_name: - self._contains_op_with_name_and_attribute( - func.node_def, op_name='XlaConvV2', attr_name='', attr_val=None + self.assertTrue( + self._contains_op_with_name_and_attribute( + func.node_def, + op_name='XlaConvV2', + attr_name='', + attr_val=None, + ) ) @test_util.run_in_graph_and_eager_modes @@ -2438,8 +2443,13 @@ def test_function_alias_preserved_in_qat(self): if alias == 'einsum_with_kernel': for func in meta_graph_def.graph_def.library.function: if func.signature.name == func_name: - self._contains_op_with_name_and_attribute( - func.node_def, op_name='XlaDotV2', attr_name='', attr_val=None + self.assertTrue( + self._contains_op_with_name_and_attribute( + func.node_def, + op_name='XlaDotV2', + attr_name='', + attr_val=None, + ) ) def test_matmul_ptq_model_with_unfreeze_constants(self): @@ -4530,7 +4540,8 @@ def test_conv_model_with_wrong_tags_raises_error(self): # StatusNotOk error. `Exception` is used here because importing # `StatusNotOk` may break the open-sourced version of TensorFlow. 
with self.assertRaisesRegex( - Exception, 'Failed to import SavedModel' + Exception, + 'could not be found in SavedModel, with available tags', ) as raises: quantize_model.quantize( self._input_saved_model_path, @@ -4541,7 +4552,7 @@ def test_conv_model_with_wrong_tags_raises_error(self): representative_dataset=data_gen, ) - self.assertEqual(raises.exception.__class__.__name__, 'StatusNotOk') + self.assertEqual(raises.exception.__class__.__name__, 'RuntimeError') @parameterized.named_parameters( ('quantize', True, 0), @@ -5285,6 +5296,67 @@ def test_gather_and_conv_model( self._output_saved_model_path, self._input_saved_model_path, 1 / 3 ) + @test_util.run_in_graph_and_eager_modes + def test_function_alias_preserved(self): + # Prepare test model + function_alias = 'conv_func' + tags = {tag_constants.SERVING} + input_type, filter_shape = dtypes.int64, (2, 3, 3, 2) + model = self._create_simple_gather_and_conv_model(input_type, filter_shape) + save_opts = save_options.SaveOptions( + function_aliases={function_alias: model.model} + ) + signatures = { + 'serving_default': model.model.get_concrete_function(), + } + saved_model_save.save( + model, self._input_saved_model_path, signatures, save_opts + ) + + # Quantize the model + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=quant_opts_pb2.QuantizationMethod.ExperimentalMethod.WEIGHT_ONLY + ), + op_set=quant_opts_pb2.XLA, + min_num_elements_for_weights=1, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + # Check if function alias is preserved + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + meta_graph_def = output_loader.get_meta_graph_def_from_tags(tags) + function_aliases = meta_graph_def.meta_info_def.function_aliases + self.assertNotEmpty(function_aliases) + self.assertCountEqual(function_aliases.values(), {function_alias}) + + # Test that the aliased function contains a quantized op. 
+ for func_name, alias in function_aliases.items(): + if alias == function_alias: + for func in meta_graph_def.graph_def.library.function: + if func.signature.name == func_name: + self.assertTrue( + self._contains_op_with_name_and_attribute( + func.node_def, + op_name='Const', + attr_name='dtype', + attr_val=attr_value_pb2.AttrValue(type=types_pb2.DT_INT8), + ) + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 4b083d6f96c3be..a6d9436f011aac 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -180,10 +180,11 @@ PYBIND11_MODULE(pywrap_quantize_model, m) { [](const absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quant_opts) + const QuantizationOptions& quant_opts, + const absl::flat_hash_map& function_aliases) -> absl::StatusOr { return QuantizePtqDynamicRange(saved_model_path, signature_keys, tags, - quant_opts); + quant_opts, function_aliases); }, R"pbdoc( Returns serialized ExportedModel that contains the quantized model's diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index f39e20c51dd269..ae0172d2294da7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -649,7 +649,8 @@ absl::StatusOr QuantizePtqDynamicRange( const absl::string_view saved_model_path, const std::vector &signature_keys, const std::unordered_set &tags, - const QuantizationOptions &quantization_options) { + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { // Convert the SavedModelBundle to an MLIR module. mlir::MLIRContext context = CreateMlirContextForTfQuantization(); @@ -673,8 +674,27 @@ absl::StatusOr QuantizePtqDynamicRange( mlir::OwningOpRef module_ref = std::move(module).value(); - TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( - module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr)); + const absl::flat_hash_map updated_function_aliases = + UpdateFunctionAliases(function_aliases, *module_ref); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. The mapping is mlir function name - user defined function + // alias for each value in the set. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(updated_function_aliases, [&](const auto &aliases) { + return aliased_function_names.insert(aliases.first); + }); + + if (aliased_function_names.empty()) { + TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr)); + } else { + TF_QUANT_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/false, + /*noinline_functions=*/aliased_function_names, module_ref.get(), + &context, bundle ? 
bundle->GetSession() : nullptr)); + } TF_QUANT_RETURN_IF_ERROR(RunPasses( /*name=*/kTfQuantPtqDynamicRangeStepName, @@ -698,8 +718,7 @@ absl::StatusOr QuantizePtqDynamicRange( RunExportPasses(export_opts, context, *module_ref)); return ConvertMlirModuleToExportedModel( - *module_ref, checkpoint_dir, - /*function_aliases=*/{}, + *module_ref, checkpoint_dir, updated_function_aliases, {asset_file_defs.begin(), asset_file_defs.end()}); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index 2344c108016a83..de64855ad8303f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -55,7 +55,8 @@ absl::StatusOr QuantizePtqDynamicRange( absl::string_view saved_model_path, const std::vector& signature_keys, const std::unordered_set& tags, - const QuantizationOptions& quantization_options); + const QuantizationOptions& quantization_options, + const absl::flat_hash_map& function_aliases); absl::StatusOr QuantizePtqModelPreCalibration( absl::string_view saved_model_path, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 53dd34cfc30d68..c2f3a4d2ea69c0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -942,12 +942,19 @@ def _dynamic_range_quantize( _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS, ) + loader = saved_model_loader.SavedModelLoader(saved_model_path) + + function_aliases = loader.get_meta_graph_def_from_tags( + tags + ).meta_info_def.function_aliases + # Apply post-training dynamic range quantization to the model. exported_model_serialized = pywrap_quantize_model.quantize_ptq_dynamic_range( saved_model_path, list(signature_keys), set(tags), quantization_options.SerializeToString(), + dict(function_aliases), ) exported_model = exported_model_pb2.ExportedModel.FromString( From 81109af3e66209e2a274e50d20a5b418fb49357d Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 27 Jul 2023 21:31:10 -0700 Subject: [PATCH 292/410] Let the raw-ops page changes apply for dev (tf-nightly) and rc builds. PiperOrigin-RevId: 551729672 --- tensorflow/tools/docs/generate2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/docs/generate2.py b/tensorflow/tools/docs/generate2.py index e0cbc588d0d343..d56e508f8bf2dc 100644 --- a/tensorflow/tools/docs/generate2.py +++ b/tensorflow/tools/docs/generate2.py @@ -44,7 +44,7 @@ from tensorflow.python.util import tf_export from tensorflow.python.util import tf_inspect -if version.parse(tf.__version__) >= version.parse("2.14"): +if version.parse(tf.__version__) >= version.parse("2.14-dev"): from tensorflow.python.util.pywrap_xla_ops import get_gpu_kernel_names # pylint: disable=g-import-not-at-top # Caution: the google and oss versions of this import are different. @@ -118,7 +118,7 @@ def build(self): # Skip the ModulePage implementation, which doesn't use a template. 
content = base_page.PageInfo.build(self) - if version.parse(tf.__version__) >= version.parse("2.14"): + if version.parse(tf.__version__) >= version.parse("2.14-dev"): raw_ops_doc = self.generate_raw_ops_doc_ge_214() else: raw_ops_doc = self.generate_raw_ops_doc_lt_214() From ab8c59cbfaeafc1809b632001916493f07da729a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 27 Jul 2023 22:46:48 -0700 Subject: [PATCH 293/410] Internal cleanup of an incorrect patch PiperOrigin-RevId: 551742663 --- third_party/stablehlo/temporary.patch: | 4513 ------------------------ 1 file changed, 4513 deletions(-) delete mode 100644 third_party/stablehlo/temporary.patch: diff --git a/third_party/stablehlo/temporary.patch: b/third_party/stablehlo/temporary.patch: deleted file mode 100644 index ddd7f6c1f59a86..00000000000000 --- a/third_party/stablehlo/temporary.patch: +++ /dev/null @@ -1,4513 +0,0 @@ -diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ---- stablehlo/BUILD.bazel -+++ stablehlo/BUILD.bazel -@@ -279,6 +279,24 @@ - ) - - cc_library( -+ name = "experimental_ops", -+ srcs = [ -+ "stablehlo/dialect/ExperimentalOps.cpp", -+ ], -+ hdrs = [ -+ "stablehlo/dialect/ExperimentalOps.h", -+ ], -+ strip_include_prefix = ".", -+ deps = [ -+ ":stablehlo_ops", -+ "@llvm-project//llvm:Support", -+ "@llvm-project//mlir:FuncDialect", -+ "@llvm-project//mlir:IR", -+ "@llvm-project//mlir:Support", -+ ], -+) -+ -+cc_library( - name = "reference_axes", - srcs = [ - "stablehlo/reference/Axes.cpp", -@@ -677,6 +695,7 @@ - deps = [ - ":base", - ":chlo_ops", -+ ":experimental_ops", - ":stablehlo_ops", - ":stablehlo_ops_inc_gen", - ":stablehlo_pass_inc_gen", -diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt ---- stablehlo/CMakeLists.txt -+++ stablehlo/CMakeLists.txt -@@ -13,135 +13,20 @@ - # See the License for the specific language governing permissions and - # limitations under the License. - # --cmake_minimum_required(VERSION 3.15.0) - --if(POLICY CMP0068) -- cmake_policy(SET CMP0068 NEW) -- set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) --endif() -- --if(POLICY CMP0075) -- cmake_policy(SET CMP0075 NEW) --endif() -- --if(POLICY CMP0077) -- cmake_policy(SET CMP0077 NEW) --endif() -- --# CMP0116: Ninja generators transform `DEPFILE`s from `add_custom_command()` --# New in CMake 3.20. https://cmake.org/cmake/help/latest/policy/CMP0116.html --if(POLICY CMP0116) -- cmake_policy(SET CMP0116 OLD) --endif() -+# This build of StableHLO is meant to be embedded in MLIR-HLO. -+# As a result, its root CMakeLists.txt is different from the original -+# CMakeLists.txt from https://github.com/openxla/stablehlo. -+# All other files of this build of StableHLO except for this one are the same -+# as the original files. -+# To get access to a standalone build of StableHLO, check out the -+# openxla/stablehlo repository. 
- - #------------------------------------------------------------------------------- - # Options and settings - #------------------------------------------------------------------------------- --option(STABLEHLO_BUILD_EMBEDDED "Build StableHLO as part of another project" OFF) --option(STABLEHLO_ENABLE_BINDINGS_PYTHON "Enables StableHLO Python bindings" OFF) --option(STABLEHLO_ENABLE_STRICT_BUILD "Build StableHLO with strict warnings and warnings as errors" OFF) - --#------------------------------------------------------------------------------- --# Project setup and globals --#------------------------------------------------------------------------------- --set(STABLEHLO_EXTERNAL_PROJECT_BUILD OFF) -- --if(NOT (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) AND NOT MLIR_BINARY_DIR) -- # Building as part of LLVM via the external project mechanism. -- set(STABLEHLO_EXTERNAL_PROJECT_BUILD ON) --else() -- # Building standalone. -- project(stablehlo LANGUAGES CXX C) -- set(CMAKE_C_STANDARD 11) -- set(CMAKE_CXX_STANDARD 17) --endif() -- --# Build with ccache if the package is present --set(LLVM_CCACHE_BUILD OFF CACHE BOOL "Set to ON for a ccache enabled build") --if(LLVM_CCACHE_BUILD) -- find_program(CCACHE_PROGRAM ccache) -- if(CCACHE_PROGRAM) -- set(LLVM_CCACHE_MAXSIZE "" CACHE STRING "Size of ccache") -- set(LLVM_CCACHE_DIR "" CACHE STRING "Directory to keep ccached data") -- set(LLVM_CCACHE_PARAMS "CCACHE_CPP2=yes CCACHE_HASHDIR=yes" -- CACHE STRING "Parameters to pass through to ccache") -- -- set(CCACHE_PROGRAM "${LLVM_CCACHE_PARAMS} ${CCACHE_PROGRAM}") -- if (LLVM_CCACHE_MAXSIZE) -- set(CCACHE_PROGRAM "CCACHE_MAXSIZE=${LLVM_CCACHE_MAXSIZE} ${CCACHE_PROGRAM}") -- endif() -- if (LLVM_CCACHE_DIR) -- set(CCACHE_PROGRAM "CCACHE_DIR=${LLVM_CCACHE_DIR} ${CCACHE_PROGRAM}") -- endif() -- set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PROGRAM}) -- else() -- message(FATAL_ERROR "Unable to find the program ccache. Set LLVM_CCACHE_BUILD to OFF") -- endif() --endif() -- --#------------------------------------------------------------------------------- --# MLIR/LLVM Configuration --#------------------------------------------------------------------------------- --if (STABLEHLO_ENABLE_STRICT_BUILD) -- set(LLVM_ENABLE_WARNINGS ON) -- set(LLVM_ENABLE_WERROR ON) -- set(LLVM_ENABLE_PEDANTIC ON) --endif() -- --# Find MLIR to install if we are building standalone. If building as part of --# another project, let it handle the MLIR dependency. The dependent project --# might use a bundled version of MLIR instead of installing, for instance. 
--if(STABLEHLO_EXTERNAL_PROJECT_BUILD) -- message(STATUS "Building StableHLO as an external LLVM project") -- set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir ) # --src-root -- set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include ) # --includedir -- set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) -- include_directories(SYSTEM ${MLIR_INCLUDE_DIR}) -- include_directories(SYSTEM ${MLIR_GENERATED_INCLUDE_DIR}) -- include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR}) -- -- set(BACKEND_PACKAGE_STRING "${PACKAGE_STRING}") -- list(APPEND CMAKE_MODULE_PATH "${MLIR_MAIN_SRC_DIR}/cmake/modules") --elseif(NOT STABLEHLO_BUILD_EMBEDDED) -- message(STATUS "Building StableHLO with an installed MLIR") -- find_package(MLIR REQUIRED CONFIG) -- message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") -- message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -- set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) -- set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) -- list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -- list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") --else() -- message(STATUS "Building StableHLO embedded in another project") --endif() -- --if(LLVM_ENABLE_ZLIB) -- find_package(ZLIB) --endif() -- --include(TableGen) --include(AddLLVM) --include(AddMLIR) --include(HandleLLVMOptions) --include_directories(${LLVM_INCLUDE_DIRS}) --include_directories(${MLIR_INCLUDE_DIRS}) --include_directories(${CMAKE_CURRENT_SOURCE_DIR}) --include_directories(${CMAKE_CURRENT_BINARY_DIR}) --link_directories(${LLVM_BUILD_LIBRARY_DIR}) --add_definitions(${LLVM_DEFINITIONS}) -- --#------------------------------------------------------------------------------- --# Python configuration --#------------------------------------------------------------------------------- -- --if(STABLEHLO_ENABLE_BINDINGS_PYTHON) -- if(NOT STABLEHLO_EXTERNAL_PROJECT_BUILD) -- message(WARNING "StableHLO Python bindings are not supported in standalone mode") -- endif() -- -- include(MLIRDetectPythonEnv) -- mlir_configure_python_dev_packages() --endif() -+set(STABLEHLO_ENABLE_BINDINGS_PYTHON ${MHLO_ENABLE_BINDINGS_PYTHON}) - - #------------------------------------------------------------------------------- - # Directory setup -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -156,6 +156,7 @@ - DenseIntElementsAttr attr; - if (!matchPattern(value, m_Constant(&attr))) return failure(); - -+ // Signless types are treated as signed, per StableHLO convention. - // Unless the type is i1 (which models boolean type from the StableHLO spec), - // in which case it's considered to be unsigned. 
- auto elementType = attr.getType().getElementType(); -@@ -599,5 +600,18 @@ - return UnrankedTensorType::get(components.getElementType()); - } - -+DenseIntElementsAttr getPaddingAttr(MLIRContext* context, -+ ArrayRef values) { -+ return DenseIntElementsAttr::get( -+ RankedTensorType::get({static_cast(values.size()) / 2, 2}, -+ IntegerType::get(context, 64)), -+ values); -+} -+ -+DenseIntElementsAttr getPaddingAttr(Builder* builder, -+ ArrayRef values) { -+ return getPaddingAttr(builder->getContext(), values); -+} -+ - } // namespace hlo - } // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -194,6 +194,10 @@ - - ShapedType createShapedType(ShapedTypeComponents components); - -+DenseIntElementsAttr getPaddingAttr(MLIRContext *context, -+ ArrayRef value); -+DenseIntElementsAttr getPaddingAttr(Builder *builder, ArrayRef value); -+ - // This interface is implemented by both StableHLO and MHLO dialects - // and is used as the foundation for sharing verification, type inference and - // prettyprinting logic between them. -@@ -249,6 +253,10 @@ - template - class BroadcastingElementwise - : public mlir::OpTrait::TraitBase {}; -+ -+template -+class IsCommutative -+ : public mlir::OpTrait::TraitBase {}; - - template - class PairwiseSameOperandAndResultType -diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td ---- stablehlo/stablehlo/dialect/Base.td -+++ stablehlo/stablehlo/dialect/Base.td -@@ -188,6 +188,11 @@ - // An operation that is essentially element-wise but may implement broadcasting - // semantics. - def HLO_BroadcastingElementwise : HLO_NativeOpTrait<"BroadcastingElementwise">; -+ -+// This class adds property that the operation is commutative. -+// Upstream IsCommutative has default folders, and StableHLO aims to have no -+// default folders or canonicalizations. -+def HLO_Commutative : HLO_NativeOpTrait<"IsCommutative">; - - // Op has pairwise operand and result type matching: the number of operands - // must be equal to the number of results and the type of ith operand must -diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt ---- stablehlo/stablehlo/dialect/CMakeLists.txt -+++ stablehlo/stablehlo/dialect/CMakeLists.txt -@@ -77,6 +77,20 @@ - target_include_directories(ChloOps INTERFACE - $ - $ -+) -+ -+add_mlir_dialect_library(ExperimentalOps -+ PARTIAL_SOURCES_INTENDED -+ ExperimentalOps.cpp -+ -+ DEPENDS -+ StablehloOpsIncGen -+ -+ LINK_LIBS PUBLIC -+ MLIRFuncDialect -+ MLIRIR -+ MLIRSupport -+ StablehloOps - ) - - add_mlir_dialect_library(StablehloRegister -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stablehlo/dialect/ExperimentalOps.cpp ---- stablehlo/stablehlo/dialect/ExperimentalOps.cpp -+++ stablehlo/stablehlo/dialect/ExperimentalOps.cpp -@@ -0,0 +1,392 @@ -+/* Copyright 2023 The StableHLO Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. 
-+==============================================================================*/ -+ -+#include "stablehlo/dialect/ExperimentalOps.h" -+ -+#include -+ -+#include "llvm/ADT/ArrayRef.h" -+#include "llvm/ADT/STLExtras.h" -+#include "mlir/Dialect/Func/IR/FuncOps.h" -+#include "mlir/IR/BuiltinAttributes.h" -+#include "mlir/IR/BuiltinOps.h" -+#include "mlir/IR/BuiltinTypeInterfaces.h" -+#include "mlir/IR/Types.h" -+ -+namespace mlir { -+namespace stablehlo { -+ -+LogicalResult DynamicReduceWindowOpAdaptor::verify() { -+ // Before checking the constraints inherited from ReduceWindowOp, -+ // make sure that the operands and the attributes of the underlying custom -+ // call make sense. -+ if (op_->getNumOperands() != 2 * op_->getNumResults() + 5) -+ return op_.emitError("expects size(operands) = 2 * size(results) + 5"); -+ if (op_->getNumResults() == 0) -+ return op_.emitError("expects size(results) > 0"); -+ for (const auto& attr : op_->getAttrs()) { -+ // api_version and backend_config have default values. -+ // call_target_name should be "stablehlo.dynamic_reduce_window". -+ // called_computations carries the body. -+ if (attr.getName() != "api_version" && -+ attr.getName() != "backend_config" && -+ attr.getName() != "call_target_name" && -+ attr.getName() != "called_computations") -+ return op_.emitError() -+ << attr.getName() << " is not a supported attribute"; -+ } -+ if (!op_.getBackendConfig().empty()) -+ return op_.emitError() << "expects an empty backend_config"; -+ if (op_.getCallTargetName() != "stablehlo.dynamic_reduce_window") -+ return op_.emitError() << "expects @stablehlo.dynamic_reduce_window"; -+ -+ // Unpack operands and attributes of the underlying custom call into -+ // operation-specific inputs. -+ auto numInputs = getInputs().size(); -+ auto inputs = op_.getInputs().slice(0, numInputs); -+ auto initValues = op_.getInputs().slice(numInputs, numInputs); -+ auto windowDimensions = op_.getInputs()[op_.getInputs().size() - 5]; -+ auto windowStrides = op_.getInputs()[op_.getInputs().size() - 4]; -+ auto baseDilations = op_.getInputs()[op_.getInputs().size() - 3]; -+ auto windowDilations = op_.getInputs()[op_.getInputs().size() - 2]; -+ auto padding = op_.getInputs()[op_.getInputs().size() - 1]; -+ auto results = op_.getResults(); -+ -+ // reduce_window_c1 -+ // This constraint hold automatically thanks to the checks that we have -+ // performed above. -+ -+ // reduce_window_i1 -+ SmallVector inputTypes; -+ for (auto [index, input] : llvm::enumerate(inputs)) { -+ auto inputType = input.getType().dyn_cast(); -+ inputTypes.push_back(inputType); -+ if (!inputType) -+ return op_.emitError() -+ << "expects inputs (e.g. operand #" << index << ") to be tensors"; -+ } -+ -+ // reduce_window_i2 -+ SmallVector initValueTypes; -+ for (auto [index, initValue] : llvm::enumerate(initValues)) { -+ auto initValueType = initValue.getType().dyn_cast(); -+ initValueTypes.push_back(initValueType); -+ if (!initValueType || !initValueType.hasRank() || -+ initValueType.getRank() != 0) -+ return op_.emitError() << "expects init_values (e.g. 
operand #" -+ << numInputs + index << ") " -+ << "to be 0-dimensional tensors"; -+ } -+ -+ // reduce_window_i3...reduce_window_i7 -+ auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr, -+ int64_t expectedRank) -> LogicalResult { -+ auto type = dynamicAttr.getType().dyn_cast(); -+ if (!type || !type.hasRank() || type.getRank() != expectedRank || -+ !type.getElementType().isIntOrIndex()) { -+ if (index < 0) index += op_->getNumOperands(); -+ return op_.emitError() -+ << "expects " << name << " (operand #" << index << ") " -+ << "to be a " << expectedRank << "-dimensional tensor " -+ << "of integer or index type"; -+ } -+ return success(); -+ }; -+ if (failed(checkRank("window_dimensions", -5, windowDimensions, 1)) || -+ failed(checkRank("window_strides", -4, windowStrides, 1)) || -+ failed(checkRank("base_dilations", -3, baseDilations, 1)) || -+ failed(checkRank("window_dilations", -2, windowDilations, 1)) || -+ failed(checkRank("padding", -1, padding, 2))) -+ return failure(); -+ -+ // reduce_window_i7 -+ auto paddingType = getPadding().getType().dyn_cast(); -+ if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 || -+ paddingType.getDimSize(1) != 2 || -+ !paddingType.getElementType().isIntOrIndex()) -+ return op_.emitError() -+ << "expects padding_type (operand #" << op_.getNumOperands() - 1 -+ << ") to be a 2-dimensional tensor of integer or index type"; -+ -+ // reduce_window_c2 -+ std::optional> inputShape; -+ for (auto inputType : inputTypes) { -+ if (!inputType.hasRank()) continue; -+ if (!inputShape) inputShape = inputType.getShape(); -+ if (failed(verifyCompatibleShape(inputType.getShape(), *inputShape))) -+ return op_.emitError() << "expects all inputs (operands 0.." << numInputs -+ << ") to have compatible shapes"; -+ } -+ -+ // reduce_window_c3 -+ for (auto [inputType, initValueType] : -+ llvm::zip(inputTypes, initValueTypes)) { -+ if (inputType.getElementType() != initValueType.getElementType()) -+ return op_.emitError() << "expects inputs (operands 0.." << numInputs -+ << ") and init_values (operands " << numInputs -+ << ".." << numInputs * 2 << ") to have pairwise " -+ << "the same element types"; -+ } -+ -+ // reduce_window_c4...reduce_window_c12 -+ // In this range, we only verify the constraints with even numbers. -+ // Verifying the constraints with odd numbers would require knowing the -+ // actual values of window_dimensions, window_strides, etc. -+ // While we certainly can try to check whether they are constants and -+ // verify them in that case, that seems like too much at this point. 
-+ auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr, -+ ArrayRef expectedShape) -> LogicalResult { -+ auto type = dynamicAttr.getType().cast(); -+ if (type.getShape() != expectedShape) { -+ if (index < 0) index += op_->getNumOperands(); -+ return op_.emitError() -+ << "expects " << name << " (operand #" << index << ") " -+ << "to have shape [" << expectedShape << "]"; -+ } -+ return success(); -+ }; -+ if (inputShape) { -+ auto inputRank = static_cast(inputShape->size()); -+ if (failed(checkShape("window_dimensions", -5, windowDimensions, -+ {inputRank})) || -+ failed(checkShape("window_strides", -4, windowStrides, {inputRank})) || -+ failed(checkShape("base_dilations", -3, baseDilations, {inputRank})) || -+ failed( -+ checkShape("window_dilations", -2, windowDilations, {inputRank})) || -+ failed(checkShape("padding", -1, padding, {inputRank, 2}))) -+ return failure(); -+ } -+ -+ // reduce_window_c13 -+ if (op_.getCalledComputations().size() != 1) -+ return op_.emitError() << "expects called_computations to have 1 element"; -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); -+ auto bodyFunc = -+ op_->getParentOfType().lookupSymbol(bodyAttr); -+ if (!bodyFunc) -+ return op_.emitError() << "expects called_computations to refer to " -+ << "a function that exists within a parent module"; -+ -+ // reduce_window_c13 -+ SmallVector expectedBodyInputs; -+ llvm::append_range(expectedBodyInputs, initValueTypes); -+ llvm::append_range(expectedBodyInputs, initValueTypes); -+ SmallVector expectedBodyOutputs; -+ llvm::append_range(expectedBodyOutputs, initValueTypes); -+ auto expectedBodyType = FunctionType::get( -+ op_.getContext(), expectedBodyInputs, expectedBodyOutputs); -+ if (bodyFunc.getFunctionType() != expectedBodyType) -+ return op_.emitError() << "expects body to have type " << expectedBodyType; -+ -+ // reduce_window_c14 -+ SmallVector resultTypes; -+ std::optional> resultShape; -+ for (auto result : results) { -+ auto resultType = result.getType().dyn_cast(); -+ resultTypes.push_back(resultType); -+ if (!resultType) return op_.emitError() << "expects results to be tensors"; -+ -+ if (!resultType.hasRank()) continue; -+ if (!resultShape) resultShape = resultType.getShape(); -+ if (failed(verifyCompatibleShape(resultType.getShape(), *resultShape))) -+ return op_.emitError() << "expects all results to have compatible shapes"; -+ } -+ -+ // reduce_window_c15 -+ // Verifying this constraint would require knowing the actual values of -+ // window_dimensions, window_strides, etc. -+ // While we certainly can try to check whether they are constants and -+ // verify them in that case, that seems like too much at this point. -+ -+ // reduce_window_c16 -+ for (auto [resultType, initValueType] : -+ llvm::zip(resultTypes, initValueTypes)) { -+ if (resultType.getElementType() != initValueType.getElementType()) -+ return op_.emitError() << "expects results and init_values (operands " -+ << numInputs << ".." 
<< numInputs * 2 << ") " -+ << "to have pairwise the same element types"; -+ } -+ -+ return success(); -+} -+ -+ValueRange DynamicReduceWindowOpAdaptor::getInputs() { -+ auto numInputs = (op_.getInputs().size() - 5) / 2; -+ return op_.getInputs().slice(0, numInputs); -+} -+ -+ValueRange DynamicReduceWindowOpAdaptor::getInitValues() { -+ auto numInputs = (op_.getInputs().size() - 5) / 2; -+ return op_.getInputs().slice(numInputs, numInputs); -+} -+ -+TypedValue DynamicReduceWindowOpAdaptor::getWindowDimensions() { -+ return op_.getInputs()[op_.getInputs().size() - 5] -+ .cast>(); -+} -+ -+TypedValue DynamicReduceWindowOpAdaptor::getWindowStrides() { -+ return op_.getInputs()[op_.getInputs().size() - 4] -+ .cast>(); -+} -+ -+TypedValue DynamicReduceWindowOpAdaptor::getBaseDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 3] -+ .cast>(); -+} -+ -+TypedValue DynamicReduceWindowOpAdaptor::getWindowDilations() { -+ return op_.getInputs()[op_.getInputs().size() - 2] -+ .cast>(); -+} -+ -+TypedValue DynamicReduceWindowOpAdaptor::getPadding() { -+ return op_.getInputs()[op_.getInputs().size() - 1] -+ .cast>(); -+} -+ -+Region& DynamicReduceWindowOpAdaptor::getBody() { -+ auto bodyAttr = op_.getCalledComputations()[0].cast(); -+ auto bodyFunc = -+ op_->getParentOfType().lookupSymbol(bodyAttr); -+ return bodyFunc.getBody(); -+} -+ -+ValueRange DynamicReduceWindowOpAdaptor::getResults() { -+ return op_.getResults(); -+} -+ -+std::optional getDynamicReduceWindowOp( -+ CustomCallOp op) { -+ if (op.getCallTargetName() != "stablehlo.dynamic_reduce_window") return {}; -+ return DynamicReduceWindowOpAdaptor(op); -+} -+ -+LogicalResult DynamicRngBitGeneratorOpAdaptor::verify() { -+ // Before checking the constraints inherited from RngBitGeneratorOp, -+ // make sure that the operands and the attributes of the underlying custom -+ // call make sense. -+ if (op_->getNumOperands() != 2) -+ return op_.emitError("expects size(operands) = 2"); -+ if (op_->getNumResults() != 2) -+ return op_.emitError("expects size(results) = 2"); -+ for (const auto& attr : op_->getAttrs()) { -+ // api_version and backend_config have default values. -+ // call_target_name should be "stablehlo.dynamic_rng_bit_generator". -+ // rng_algorithm comes from the operation. -+ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && -+ attr.getName() != "call_target_name" && -+ attr.getName() != "rng_algorithm") -+ return op_.emitError() -+ << attr.getName() << " is not a supported attribute"; -+ } -+ if (!op_.getBackendConfig().empty()) -+ return op_.emitError() << "expects an empty backend_config"; -+ if (op_.getCallTargetName() != "stablehlo.dynamic_rng_bit_generator") -+ return op_.emitError() << "expects @stablehlo.dynamic_rng_bit_generator"; -+ if (!op_->hasAttr("rng_algorithm")) -+ return op_.emitError() << "expects an rng_algorithm"; -+ -+ // Unpack operands and attributes of the underlying custom call into -+ // operation-specific inputs. -+ auto rngAlgorithmAttr = op_->getAttr("rng_algorithm"); -+ auto initialState = op_.getInputs()[0]; -+ auto outputShape = op_.getInputs()[1]; -+ auto outputState = op_.getResults()[0]; -+ auto output = op_.getResults()[1]; -+ -+ // dynamic_rng_bit_generator_i1 -+ if (!rngAlgorithmAttr.isa()) -+ return op_.emitError() -+ << "expects a #stablehlo rng_algorithm"; -+ -+ // dynamic_rng_bit_generator_i2 -+ // TODO(#643): Clarify supported types for RngBitGeneratorOp. 
-+ auto initialStateType = initialState.getType().dyn_cast(); -+ if (!initialStateType || !initialStateType.getElementType().isIntOrFloat()) -+ return op_.emitError() -+ << "expects initial_state (operand #0) " -+ << "to be a tensor of integer or floating-point type"; -+ -+ // dynamic_rng_bit_generator_i3 -+ auto outputShapeType = outputShape.getType().dyn_cast(); -+ if (!outputShapeType || !outputShapeType.hasRank() || -+ outputShapeType.getRank() != 1 || -+ !outputShapeType.getElementType().isIntOrIndex()) -+ return op_.emitError() -+ << "expects output_shape (operand #1) " -+ << "to be a 1-dimensional tensor of integer or index type"; -+ -+ // dynamic_rng_bit_generator_o1 -+ // TODO(#643): Clarify supported types for RngBitGeneratorOp. -+ auto outputStateType = outputState.getType().dyn_cast(); -+ if (!outputStateType || !outputStateType.getElementType().isIntOrFloat()) -+ return op_.emitError() -+ << "expects output_state (result #0) " -+ << "to be a tensor of integer or floating-point type"; -+ -+ // dynamic_rng_bit_generator_o2 -+ auto outputType = output.getType().dyn_cast(); -+ if (!outputType || !outputType.getElementType().isIntOrFloat()) -+ return op_.emitError() -+ << "expects output (result #1) " -+ << "to be a tensor of integer or floating-point type"; -+ -+ // dynamic_rng_bit_generator_c1 -+ if (!hlo::isCompatibleForHloTypeInference(initialStateType, outputStateType)) -+ return op_.emitError() -+ << "expects initial_state (operand #0) and output_state (result #0) " -+ << "to have compatible shapes"; -+ -+ // dynamic_rng_bit_generator_c2 -+ // TODO(#486): Verify rng_algorithm in RngBitGeneratorOp. -+ -+ // dynamic_rng_bit_generator_c3 -+ if (!hlo::isCompatibleForHloTypeInference(outputShape, outputType)) -+ return op_.emitError() << "expects output (result #1) to have shape " -+ << "compatible with output_shape (operand #2)"; -+ -+ return success(); -+} -+ -+RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() { -+ return op_->getAttr("rng_algorithm").cast().getValue(); -+} -+ -+TypedValue DynamicRngBitGeneratorOpAdaptor::getInitialState() { -+ return op_.getInputs()[0].cast>(); -+} -+ -+TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputShape() { -+ return op_.getInputs()[1].cast>(); -+} -+ -+TypedValue DynamicRngBitGeneratorOpAdaptor::getOutputState() { -+ return op_.getResults()[0].cast>(); -+} -+ -+TypedValue DynamicRngBitGeneratorOpAdaptor::getOutput() { -+ return op_.getResults()[1].cast>(); -+} -+ -+std::optional getDynamicRngBitGeneratorOp( -+ CustomCallOp op) { -+ if (op.getCallTargetName() != "stablehlo.dynamic_rng_bit_generator") -+ return {}; -+ return DynamicRngBitGeneratorOpAdaptor(op); -+} -+ -+} // namespace stablehlo -+} // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo/dialect/ExperimentalOps.h ---- stablehlo/stablehlo/dialect/ExperimentalOps.h -+++ stablehlo/stablehlo/dialect/ExperimentalOps.h -@@ -0,0 +1,170 @@ -+/* Copyright 2023 The StableHLO Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. 
-+==============================================================================*/ -+ -+#ifndef STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -+#define STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -+ -+// This file supports XLA-specific experiments with the StableHLO opset. -+// These experiments are not yet ready to be upstreamed to openxla/stablehlo -+// and are incubating towards the respective StableHLO RFCs. -+// -+// Custom calls (which are the implementation vehicle of these experiments) -+// don't have compatibility guarantees within the StableHLO process, but -+// the StableHLO team at Google provides out-of-band guarantees for these -+// custom calls, with the same compatibility window as StableHLO upstream. -+ -+#include "mlir/IR/Operation.h" -+#include "mlir/IR/Region.h" -+#include "mlir/IR/Value.h" -+#include "mlir/IR/ValueRange.h" -+#include "mlir/Support/LogicalResult.h" -+#include "stablehlo/dialect/StablehloOps.h" -+ -+namespace mlir { -+namespace stablehlo { -+ -+// The DynamicReduceWindowOp experiment provides a dynamic version of -+// ReduceWindowOp. Once the dynamism RFC is figured out, we expect to have an -+// upstream representation for this notion. -+// -+// Within this experiment, DynamicReduceWindowOp is represented via the -+// `stablehlo.custom_call @stablehlo.dynamic_reduce_window` custom call. -+// This custom call has the following operands which represent a dynamic version -+// of operands and attributes of ReduceWindowOp: -+// * [0:N] => inputs -+// * [N:2*N] => init_values -+// * [-5] => window_dimensions -+// * [-4] => window_strides -+// * [-3] => base_dilations -+// * [-2] => window_dilations -+// * [-1] => padding -+// Additionally, to represent the body of DynamicReduceWindowOp, the custom call -+// has a satellite function attached to the custom call via called_computations. -+// -+// Semantics of DynamicReduceWindowOp are inherited from semantics of -+// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window -+// with the following exceptions: -+// 1) All tensor constants, i.e. window_dimensions, window_strides, -+// base_dilations, window_dilations and padding, become tensors of -+// integer type. -+// 2) As a result, some of the constraints can no longer be validated -+// statically. However, this operation still expects these constraints -+// to hold dynamically, and if they don't hold, the behavior is undefined. -+class DynamicReduceWindowOpAdaptor { -+ public: -+ DynamicReduceWindowOpAdaptor(CustomCallOp op) : op_(op) {} -+ operator Operation*() { return op_; } -+ Operation* operator->() { return op_; } -+ -+ // Same accessors as for stablehlo::ReduceWindowOp, except that all the -+ // std::optional attributes have turned into values. -+ // These accessors assume that the operation is well-formed (i.e. that it -+ // can pass verification). -+ ValueRange getInputs(); -+ ValueRange getInitValues(); -+ TypedValue getWindowDimensions(); -+ TypedValue getWindowStrides(); -+ TypedValue getBaseDilations(); -+ TypedValue getWindowDilations(); -+ TypedValue getPadding(); -+ Region& getBody(); -+ ValueRange getResults(); -+ -+ // Verifies the constraints documented above. -+ // Emits errors if errors are detected. -+ LogicalResult verify(); -+ -+ private: -+ CustomCallOp op_; -+}; -+ -+// Wraps a custom call in a DynamicReduceWindowAdaptor. -+// Fails if the call_target_name of the custom call doesn't match -+// "stablehlo.dynamic_reduce_window". 
-+std::optional<DynamicReduceWindowOpAdaptor> getDynamicReduceWindowOp(
-+    CustomCallOp op);
-+
-+// The DynamicRngBitGeneratorOp experiment provides a dynamic version of
-+// RngBitGeneratorOp. Once the dynamism RFC is figured out, we expect to have an
-+// upstream representation for this notion.
-+//
-+// Within this experiment, DynamicRngBitGeneratorOp is represented via the
-+// `stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator` custom call.
-+// This custom call has the regular operand of RngBitGeneratorOp plus an
-+// additional `output_shape` operand that determines the shape of the output:
-+//   * [0] => initial_state
-+//   * [1] => output_shape
-+//
-+// Semantics of DynamicRngBitGeneratorOp are inherited from semantics of
-+// https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rng_bit_generator
-+// extended with an additional input (I3) and an additional constraint (C3):
-+//
-+// #### Inputs
-+//
-+// | Label | Name            | Type                                         |
-+// |-------|-----------------|----------------------------------------------|
-+// | (I1)  | `rng_algorithm` | enum of `DEFAULT`, `THREE_FRY`, and `PHILOX` |
-+// | (I2)  | `initial_state` | 1-dimensional tensor of type `ui64`          |
-+// | (I3)  | `output_shape`  | 1-dimensional tensor of integer type         |
-+//
-+// #### Outputs
-+//
-+// | Name           | Type                                      |
-+// |----------------|-------------------------------------------|
-+// | `output_state` | 1-dimensional tensor of type `ui64`       |
-+// | `output`       | tensor of integer or floating-point type |
-+//
-+// #### Constraints
-+//
-+// * (C1) `type(initial_state) = type(output_state)`.
-+// * (C2) `size(initial_state)` is defined as:
-+//   * implementation-defined if `rng_algorithm = DEFAULT`.
-+//   * `2` if `rng_algorithm = THREE_FRY`.
-+//   * `2` or `3` if `rng_algorithm = PHILOX`.
-+// * (C3) `shape(output) = output_shape`.
-+class DynamicRngBitGeneratorOpAdaptor {
-+ public:
-+  DynamicRngBitGeneratorOpAdaptor(CustomCallOp op) : op_(op) {}
-+  operator Operation*() { return op_; }
-+  Operation* operator->() { return op_; }
-+
-+  // Same accessors as for stablehlo::RngBitGeneratorOp, extended with the
-+  // additional `output_shape` operand.
-+  // These accessors assume that the operation is well-formed (i.e. that it
-+  // can pass verification).
-+  RngAlgorithm getRngAlgorithm();
-+  TypedValue getInitialState();
-+  TypedValue getOutputShape();
-+  TypedValue getOutputState();
-+  TypedValue getOutput();
-+
-+  // Verifies the constraints documented above.
-+  // Emits errors if errors are detected.
-+  LogicalResult verify();
-+
-+ private:
-+  CustomCallOp op_;
-+};
-+
-+// Wraps a custom call in a DynamicRngBitGeneratorOpAdaptor.
-+// Fails if the call_target_name of the custom call doesn't match
-+// "stablehlo.dynamic_rng_bit_generator".
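-+//
-+// For illustration only (editorial sketch, not part of this patch: the value
-+// names, element types, and the generic-syntax attribute spelling below are
-+// assumptions), such a custom call could look roughly like this:
-+//
-+//   %output_state, %output = "stablehlo.custom_call"(%initial_state,
-+//       %output_shape) {
-+//     call_target_name = "stablehlo.dynamic_rng_bit_generator",
-+//     rng_algorithm = #stablehlo<rng_algorithm DEFAULT>
-+//   } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor<?x?xui32>)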
-+std::optional getDynamicRngBitGeneratorOp( -+ CustomCallOp op); -+ -+} // namespace stablehlo -+} // namespace mlir -+ -+#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -1467,7 +1467,7 @@ - if (innerOp.getNumOperands() != 2 || - !innerOp.hasTrait() || - !hasSameOperandAndResultTypes(innerOp) || -- !innerOp.hasTrait() || -+ !innerOp.hasTrait() || - !innerOp.hasTrait()) - return false; - -@@ -1664,7 +1664,7 @@ - if (!innerOpDialect || !innerOpDialect->getNamespace().equals("stablehlo") || - !innerOpNameInfo->hasTrait::Impl>() || - !innerOpNameInfo->hasTrait() || -- !innerOpNameInfo->hasTrait() || -+ !innerOpNameInfo->hasTrait() || - !innerOpNameInfo->hasTrait()) { - parser.emitError(loc, - "expected the inner-op to be a commutative binary-op from " -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td ---- stablehlo/stablehlo/dialect/StablehloOps.td -+++ stablehlo/stablehlo/dialect/StablehloOps.td -@@ -687,7 +687,7 @@ - } - - def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add", -- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { -+ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { - let summary = "Add operation"; - let description = [{ - Performs element-wise addition of two tensors `lhs` and `rhs` and produces a -@@ -769,7 +769,7 @@ - } - - def StableHLO_MaxOp : StableHLO_BinaryElementwiseOp<"maximum", -- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { -+ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { - let summary = "Max operation"; - let description = [{ - Performs element-wise max operation on tensors `lhs` and `rhs` and produces -@@ -786,7 +786,7 @@ - } - - def StableHLO_MinOp : StableHLO_BinaryElementwiseOp<"minimum", -- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { -+ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { - let summary = "Min operation"; - let description = [{ - Performs element-wise min operation on tensors `lhs` and `rhs` and produces a -@@ -803,7 +803,7 @@ - } - - def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply", -- [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { -+ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { - let summary = "Mul operation"; - let description = [{ - Performs element-wise product of two tensors `lhs` and `rhs` and produces a -@@ -933,7 +933,7 @@ - // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations - class StableHLO_BinaryBiwiseOrLogicalElementwiseOp : - StableHLO_BinaryElementwiseOp { -+ [HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { - let arguments = (ins - HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs -diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/acosh_shape_bfloat16_20_20.mlir -@@ -16,9 +16,9 @@ - %10 = stablehlo.constant dense<6.914060e-01> : tensor<20x20xbf16> - %11 = stablehlo.add %8, %10 : tensor<20x20xbf16> - %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> -- %13 = stablehlo.add %12, %0 : tensor<20x20xbf16> -+ %13 = stablehlo.add %0, %12 : tensor<20x20xbf16> - %14 = 
stablehlo.constant dense<-1.000000e+00> : tensor<20x20xbf16> -- %15 = stablehlo.add %14, %0 : tensor<20x20xbf16> -+ %15 = stablehlo.add %0, %14 : tensor<20x20xbf16> - %16 = stablehlo.multiply %13, %15 : tensor<20x20xbf16> - %17 = stablehlo.sqrt %16 : tensor<20x20xbf16> - %18 = stablehlo.add %0, %17 : tensor<20x20xbf16> -diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/acosh_shape_float16_20_20.mlir -@@ -16,9 +16,9 @@ - %10 = stablehlo.constant dense<6.933590e-01> : tensor<20x20xf16> - %11 = stablehlo.add %8, %10 : tensor<20x20xf16> - %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> -- %13 = stablehlo.add %12, %0 : tensor<20x20xf16> -+ %13 = stablehlo.add %0, %12 : tensor<20x20xf16> - %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf16> -- %15 = stablehlo.add %14, %0 : tensor<20x20xf16> -+ %15 = stablehlo.add %0, %14 : tensor<20x20xf16> - %16 = stablehlo.multiply %13, %15 : tensor<20x20xf16> - %17 = stablehlo.sqrt %16 : tensor<20x20xf16> - %18 = stablehlo.add %0, %17 : tensor<20x20xf16> -diff --ruN a/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/acosh_shape_float32_20_20.mlir -@@ -16,9 +16,9 @@ - %10 = stablehlo.constant dense<0.693147182> : tensor<20x20xf32> - %11 = stablehlo.add %8, %10 : tensor<20x20xf32> - %12 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> -- %13 = stablehlo.add %12, %0 : tensor<20x20xf32> -+ %13 = stablehlo.add %0, %12 : tensor<20x20xf32> - %14 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %15 = stablehlo.add %14, %0 : tensor<20x20xf32> -+ %15 = stablehlo.add %0, %14 : tensor<20x20xf32> - %16 = stablehlo.multiply %13, %15 : tensor<20x20xf32> - %17 = stablehlo.sqrt %16 : tensor<20x20xf32> - %18 = stablehlo.add %0, %17 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/asin_shape_bfloat16_20_20.mlir -@@ -11,9 +11,9 @@ - %5 = stablehlo.multiply %0, %0 : tensor<20x20xbf16> - %6 = stablehlo.subtract %4, %5 : tensor<20x20xbf16> - %7 = stablehlo.sqrt %6 : tensor<20x20xbf16> -- %8 = stablehlo.add %3, %7 : tensor<20x20xbf16> -+ %8 = stablehlo.add %7, %3 : tensor<20x20xbf16> - %9 = stablehlo.atan2 %0, %8 : tensor<20x20xbf16> -- %10 = stablehlo.multiply %2, %9 : tensor<20x20xbf16> -+ %10 = stablehlo.multiply %9, %2 : tensor<20x20xbf16> - %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xbf16>, tensor<20x20xbf16>) -> tensor - return %11 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir ---- stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir -+++ stablehlo/stablehlo/testdata/asin_shape_complex64_20_20.mlir -@@ -11,9 +11,9 @@ - %5 = stablehlo.multiply %0, %0 : tensor<20x20xcomplex> - %6 = stablehlo.subtract %4, %5 : tensor<20x20xcomplex> - %7 = stablehlo.sqrt %6 : tensor<20x20xcomplex> -- %8 = stablehlo.add %3, %7 : tensor<20x20xcomplex> -+ %8 = stablehlo.add %7, %3 : tensor<20x20xcomplex> - %9 = stablehlo.atan2 %0, %8 : tensor<20x20xcomplex> -- 
%10 = stablehlo.multiply %2, %9 : tensor<20x20xcomplex> -+ %10 = stablehlo.multiply %9, %2 : tensor<20x20xcomplex> - %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xcomplex>, tensor<20x20xcomplex>) -> tensor - return %11 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/asin_shape_float16_20_20.mlir -@@ -11,9 +11,9 @@ - %5 = stablehlo.multiply %0, %0 : tensor<20x20xf16> - %6 = stablehlo.subtract %4, %5 : tensor<20x20xf16> - %7 = stablehlo.sqrt %6 : tensor<20x20xf16> -- %8 = stablehlo.add %3, %7 : tensor<20x20xf16> -+ %8 = stablehlo.add %7, %3 : tensor<20x20xf16> - %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf16> -- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf16> -+ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf16> - %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf16>, tensor<20x20xf16>) -> tensor - return %11 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/asin_shape_float32_20_20.mlir -@@ -11,9 +11,9 @@ - %5 = stablehlo.multiply %0, %0 : tensor<20x20xf32> - %6 = stablehlo.subtract %4, %5 : tensor<20x20xf32> - %7 = stablehlo.sqrt %6 : tensor<20x20xf32> -- %8 = stablehlo.add %3, %7 : tensor<20x20xf32> -+ %8 = stablehlo.add %7, %3 : tensor<20x20xf32> - %9 = stablehlo.atan2 %0, %8 : tensor<20x20xf32> -- %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> -+ %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> - %11 = stablehlo.custom_call @check.eq(%10, %1) : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor - return %11 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/asinh_shape_bfloat16_20_20.mlir -@@ -28,7 +28,7 @@ - %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xbf16> - %23 = stablehlo.add %21, %22 : tensor<20x20xbf16> - %24 = stablehlo.sqrt %23 : tensor<20x20xbf16> -- %25 = stablehlo.add %18, %24 : tensor<20x20xbf16> -+ %25 = stablehlo.add %24, %18 : tensor<20x20xbf16> - %26 = stablehlo.divide %17, %25 : tensor<20x20xbf16> - %27 = stablehlo.multiply %16, %26 : tensor<20x20xbf16> - %28 = stablehlo.add %15, %27 : tensor<20x20xbf16> -diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/asinh_shape_float16_20_20.mlir -@@ -28,7 +28,7 @@ - %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf16> - %23 = stablehlo.add %21, %22 : tensor<20x20xf16> - %24 = stablehlo.sqrt %23 : tensor<20x20xf16> -- %25 = stablehlo.add %18, %24 : tensor<20x20xf16> -+ %25 = stablehlo.add %24, %18 : tensor<20x20xf16> - %26 = stablehlo.divide %17, %25 : tensor<20x20xf16> - %27 = stablehlo.multiply %16, %26 : tensor<20x20xf16> - %28 = stablehlo.add %15, %27 : tensor<20x20xf16> -diff --ruN a/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir -+++ 
stablehlo/stablehlo/testdata/asinh_shape_float32_20_20.mlir -@@ -28,7 +28,7 @@ - %22 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %23 = stablehlo.add %21, %22 : tensor<20x20xf32> - %24 = stablehlo.sqrt %23 : tensor<20x20xf32> -- %25 = stablehlo.add %18, %24 : tensor<20x20xf32> -+ %25 = stablehlo.add %24, %18 : tensor<20x20xf32> - %26 = stablehlo.divide %17, %25 : tensor<20x20xf32> - %27 = stablehlo.multiply %16, %26 : tensor<20x20xf32> - %28 = stablehlo.add %15, %27 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i0e_shape_bfloat16_20_20.mlir -@@ -33,7 +33,7 @@ - %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %9 = stablehlo.constant dense<5.000000e-01> : tensor - %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> -- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> -+ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> - %12 = stablehlo.constant dense<2.000000e+00> : tensor - %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> -@@ -133,7 +133,7 @@ - %108 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> - %109 = stablehlo.add %106, %108 : tensor<20x20xf32> - %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> -- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> -+ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> - %112 = stablehlo.constant dense<5.000000e-01> : tensor - %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %114 = stablehlo.constant dense<3.200000e+01> : tensor -@@ -182,7 +182,7 @@ - %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> - %158 = stablehlo.add %155, %157 : tensor<20x20xf32> - %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> -- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> -+ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> - %161 = stablehlo.sqrt %3 : tensor<20x20xf32> - %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> - %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float16_20_20.mlir -@@ -33,7 +33,7 @@ - %8 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %9 = stablehlo.constant dense<5.000000e-01> : tensor - %10 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> -- %11 = stablehlo.multiply %10, %3 : tensor<20x20xf32> -+ %11 = stablehlo.multiply %3, %10 : tensor<20x20xf32> - %12 = stablehlo.constant dense<2.000000e+00> : tensor - %13 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %14 = stablehlo.subtract %11, %13 : tensor<20x20xf32> -@@ -133,7 +133,7 @@ - %108 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> - %109 = stablehlo.add %106, %108 : tensor<20x20xf32> - %110 = stablehlo.subtract %109, %99 : tensor<20x20xf32> -- %111 = stablehlo.multiply %8, %110 : tensor<20x20xf32> -+ %111 = stablehlo.multiply %110, %8 : tensor<20x20xf32> - %112 = stablehlo.constant dense<5.000000e-01> : tensor - %113 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %114 = stablehlo.constant 
dense<3.200000e+01> : tensor -@@ -182,7 +182,7 @@ - %157 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> - %158 = stablehlo.add %155, %157 : tensor<20x20xf32> - %159 = stablehlo.subtract %158, %148 : tensor<20x20xf32> -- %160 = stablehlo.multiply %113, %159 : tensor<20x20xf32> -+ %160 = stablehlo.multiply %159, %113 : tensor<20x20xf32> - %161 = stablehlo.sqrt %3 : tensor<20x20xf32> - %162 = stablehlo.divide %160, %161 : tensor<20x20xf32> - %163 = stablehlo.select %6, %111, %162 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i0e_shape_float32_20_20.mlir -@@ -32,7 +32,7 @@ - %7 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %8 = stablehlo.constant dense<5.000000e-01> : tensor - %9 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> -- %10 = stablehlo.multiply %9, %2 : tensor<20x20xf32> -+ %10 = stablehlo.multiply %2, %9 : tensor<20x20xf32> - %11 = stablehlo.constant dense<2.000000e+00> : tensor - %12 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %13 = stablehlo.subtract %10, %12 : tensor<20x20xf32> -@@ -132,7 +132,7 @@ - %107 = stablehlo.constant dense<0.676795303> : tensor<20x20xf32> - %108 = stablehlo.add %105, %107 : tensor<20x20xf32> - %109 = stablehlo.subtract %108, %98 : tensor<20x20xf32> -- %110 = stablehlo.multiply %7, %109 : tensor<20x20xf32> -+ %110 = stablehlo.multiply %109, %7 : tensor<20x20xf32> - %111 = stablehlo.constant dense<5.000000e-01> : tensor - %112 = stablehlo.constant dense<5.000000e-01> : tensor<20x20xf32> - %113 = stablehlo.constant dense<3.200000e+01> : tensor -@@ -181,7 +181,7 @@ - %156 = stablehlo.constant dense<0.804490387> : tensor<20x20xf32> - %157 = stablehlo.add %154, %156 : tensor<20x20xf32> - %158 = stablehlo.subtract %157, %147 : tensor<20x20xf32> -- %159 = stablehlo.multiply %112, %158 : tensor<20x20xf32> -+ %159 = stablehlo.multiply %158, %112 : tensor<20x20xf32> - %160 = stablehlo.sqrt %2 : tensor<20x20xf32> - %161 = stablehlo.divide %159, %160 : tensor<20x20xf32> - %162 = stablehlo.select %5, %110, %161 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i1e_shape_bfloat16_20_20.mlir -@@ -11,7 +11,7 @@ - %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> - %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> -- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> -+ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> - %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> - %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> - %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float16_20_20.mlir -@@ -11,7 +11,7 @@ - %5 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %6 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> 
- %7 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> -- %8 = stablehlo.multiply %4, %3 : tensor<20x20xf32> -+ %8 = stablehlo.multiply %3, %4 : tensor<20x20xf32> - %9 = stablehlo.subtract %8, %5 : tensor<20x20xf32> - %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> - %11 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/bessel_i1e_shape_float32_20_20.mlir -@@ -10,7 +10,7 @@ - %4 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %5 = stablehlo.constant dense<3.200000e+01> : tensor<20x20xf32> - %6 = stablehlo.constant dense<8.000000e+00> : tensor<20x20xf32> -- %7 = stablehlo.multiply %3, %2 : tensor<20x20xf32> -+ %7 = stablehlo.multiply %2, %3 : tensor<20x20xf32> - %8 = stablehlo.subtract %7, %4 : tensor<20x20xf32> - %9 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> - %10 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir ---- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir -+++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_even_enable_xla_True_dynamic.mlir -@@ -16,7 +16,7 @@ - %11 = stablehlo.constant dense<1> : tensor - %12 = stablehlo.subtract %3, %11 : tensor - %13 = stablehlo.select %10, %12, %3 : tensor, tensor -- %14 = stablehlo.multiply %2, %13 : tensor -+ %14 = stablehlo.multiply %13, %2 : tensor - %15 = stablehlo.subtract %1, %14 : tensor - %16 = stablehlo.constant dense<2> : tensor - %17 = stablehlo.add %15, %16 : tensor -@@ -32,7 +32,7 @@ - %27 = stablehlo.constant dense<1> : tensor - %28 = stablehlo.subtract %19, %27 : tensor - %29 = stablehlo.select %26, %28, %19 : tensor, tensor -- %30 = stablehlo.multiply %18, %29 : tensor -+ %30 = stablehlo.multiply %29, %18 : tensor - %31 = stablehlo.subtract %17, %30 : tensor - %32 = stablehlo.constant dense<-1> : tensor - %33 = stablehlo.multiply %arg0, %32 : tensor -@@ -48,7 +48,7 @@ - %43 = stablehlo.constant dense<1> : tensor - %44 = stablehlo.subtract %35, %43 : tensor - %45 = stablehlo.select %42, %44, %35 : tensor, tensor -- %46 = stablehlo.multiply %34, %45 : tensor -+ %46 = stablehlo.multiply %45, %34 : tensor - %47 = stablehlo.subtract %33, %46 : tensor - %48 = stablehlo.constant dense<-1> : tensor - %49 = stablehlo.multiply %arg0, %48 : tensor -@@ -64,7 +64,7 @@ - %59 = stablehlo.constant dense<1> : tensor - %60 = stablehlo.subtract %51, %59 : tensor - %61 = stablehlo.select %58, %60, %51 : tensor, tensor -- %62 = stablehlo.multiply %50, %61 : tensor -+ %62 = stablehlo.multiply %61, %50 : tensor - %63 = stablehlo.subtract %49, %62 : tensor - %64 = stablehlo.constant dense<2> : tensor - %65 = stablehlo.add %63, %64 : tensor -@@ -80,7 +80,7 @@ - %75 = stablehlo.constant dense<1> : tensor - %76 = stablehlo.subtract %67, %75 : tensor - %77 = stablehlo.select %74, %76, %67 : tensor, tensor -- %78 = stablehlo.multiply %66, %77 : tensor -+ %78 = stablehlo.multiply %77, %66 : tensor - %79 = stablehlo.subtract %65, %78 : tensor - %80 = stablehlo.constant dense<-1> : tensor - %81 = stablehlo.multiply %77, %80 : tensor -diff --ruN 
a/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir ---- stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir -+++ stablehlo/stablehlo/testdata/conv_general_dilated_1d_stride_2_odd_enable_xla_True_dynamic.mlir -@@ -16,7 +16,7 @@ - %11 = stablehlo.constant dense<1> : tensor - %12 = stablehlo.subtract %3, %11 : tensor - %13 = stablehlo.select %10, %12, %3 : tensor, tensor -- %14 = stablehlo.multiply %2, %13 : tensor -+ %14 = stablehlo.multiply %13, %2 : tensor - %15 = stablehlo.subtract %1, %14 : tensor - %16 = stablehlo.constant dense<2> : tensor - %17 = stablehlo.add %15, %16 : tensor -@@ -32,7 +32,7 @@ - %27 = stablehlo.constant dense<1> : tensor - %28 = stablehlo.subtract %19, %27 : tensor - %29 = stablehlo.select %26, %28, %19 : tensor, tensor -- %30 = stablehlo.multiply %18, %29 : tensor -+ %30 = stablehlo.multiply %29, %18 : tensor - %31 = stablehlo.subtract %17, %30 : tensor - %32 = stablehlo.constant dense<-1> : tensor - %33 = stablehlo.multiply %arg0, %32 : tensor -@@ -48,7 +48,7 @@ - %43 = stablehlo.constant dense<1> : tensor - %44 = stablehlo.subtract %35, %43 : tensor - %45 = stablehlo.select %42, %44, %35 : tensor, tensor -- %46 = stablehlo.multiply %34, %45 : tensor -+ %46 = stablehlo.multiply %45, %34 : tensor - %47 = stablehlo.subtract %33, %46 : tensor - %48 = stablehlo.constant dense<-1> : tensor - %49 = stablehlo.multiply %arg0, %48 : tensor -@@ -64,7 +64,7 @@ - %59 = stablehlo.constant dense<1> : tensor - %60 = stablehlo.subtract %51, %59 : tensor - %61 = stablehlo.select %58, %60, %51 : tensor, tensor -- %62 = stablehlo.multiply %50, %61 : tensor -+ %62 = stablehlo.multiply %61, %50 : tensor - %63 = stablehlo.subtract %49, %62 : tensor - %64 = stablehlo.constant dense<2> : tensor - %65 = stablehlo.add %63, %64 : tensor -@@ -80,7 +80,7 @@ - %75 = stablehlo.constant dense<1> : tensor - %76 = stablehlo.subtract %67, %75 : tensor - %77 = stablehlo.select %74, %76, %67 : tensor, tensor -- %78 = stablehlo.multiply %66, %77 : tensor -+ %78 = stablehlo.multiply %77, %66 : tensor - %79 = stablehlo.subtract %65, %78 : tensor - %80 = stablehlo.constant dense<-1> : tensor - %81 = stablehlo.multiply %77, %80 : tensor -diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/digamma_shape_bfloat16_20_20.mlir -@@ -21,7 +21,7 @@ - %15 = stablehlo.divide %11, %14 : tensor<20x20xf32> - %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> - %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> -- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> -+ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> - %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %21 = stablehlo.add %8, %20 : tensor<20x20xf32> -@@ -79,11 +79,11 @@ - %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> - %74 = stablehlo.add %66, %73 : tensor<20x20xf32> - %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> -+ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> - %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> - %79 = stablehlo.log_plus_one %78 : 
tensor<20x20xf32> -- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> -+ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> - %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> - %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> - %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> -@@ -95,10 +95,10 @@ - %89 = stablehlo.abs %88 : tensor<20x20xf32> - %90 = stablehlo.add %2, %89 : tensor<20x20xf32> - %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> -+ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> - %93 = stablehlo.cosine %92 : tensor<20x20xf32> - %94 = stablehlo.sine %92 : tensor<20x20xf32> -- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> -+ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> - %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> - %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> - %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/digamma_shape_float16_20_20.mlir -@@ -21,7 +21,7 @@ - %15 = stablehlo.divide %11, %14 : tensor<20x20xf32> - %16 = stablehlo.subtract %9, %15 : tensor<20x20xf32> - %17 = stablehlo.divide %11, %13 : tensor<20x20xf32> -- %18 = stablehlo.add %10, %17 : tensor<20x20xf32> -+ %18 = stablehlo.add %17, %10 : tensor<20x20xf32> - %19 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %20 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %21 = stablehlo.add %8, %20 : tensor<20x20xf32> -@@ -79,11 +79,11 @@ - %73 = stablehlo.divide %67, %69 : tensor<20x20xf32> - %74 = stablehlo.add %66, %73 : tensor<20x20xf32> - %75 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %76 = stablehlo.add %75, %8 : tensor<20x20xf32> -+ %76 = stablehlo.add %8, %75 : tensor<20x20xf32> - %77 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %78 = stablehlo.divide %8, %75 : tensor<20x20xf32> - %79 = stablehlo.log_plus_one %78 : tensor<20x20xf32> -- %80 = stablehlo.add %77, %79 : tensor<20x20xf32> -+ %80 = stablehlo.add %79, %77 : tensor<20x20xf32> - %81 = stablehlo.divide %72, %74 : tensor<20x20xf32> - %82 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> - %83 = stablehlo.divide %82, %76 : tensor<20x20xf32> -@@ -95,10 +95,10 @@ - %89 = stablehlo.abs %88 : tensor<20x20xf32> - %90 = stablehlo.add %2, %89 : tensor<20x20xf32> - %91 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %92 = stablehlo.multiply %91, %90 : tensor<20x20xf32> -+ %92 = stablehlo.multiply %90, %91 : tensor<20x20xf32> - %93 = stablehlo.cosine %92 : tensor<20x20xf32> - %94 = stablehlo.sine %92 : tensor<20x20xf32> -- %95 = stablehlo.multiply %91, %93 : tensor<20x20xf32> -+ %95 = stablehlo.multiply %93, %91 : tensor<20x20xf32> - %96 = stablehlo.divide %95, %94 : tensor<20x20xf32> - %97 = stablehlo.subtract %85, %96 : tensor<20x20xf32> - %98 = stablehlo.select %4, %97, %85 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/digamma_shape_float32_20_20.mlir -@@ -20,7 +20,7 @@ - %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> - %15 = stablehlo.subtract %8, %14 : 
tensor<20x20xf32> - %16 = stablehlo.divide %10, %12 : tensor<20x20xf32> -- %17 = stablehlo.add %9, %16 : tensor<20x20xf32> -+ %17 = stablehlo.add %16, %9 : tensor<20x20xf32> - %18 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %19 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %20 = stablehlo.add %7, %19 : tensor<20x20xf32> -@@ -78,11 +78,11 @@ - %72 = stablehlo.divide %66, %68 : tensor<20x20xf32> - %73 = stablehlo.add %65, %72 : tensor<20x20xf32> - %74 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %75 = stablehlo.add %74, %7 : tensor<20x20xf32> -+ %75 = stablehlo.add %7, %74 : tensor<20x20xf32> - %76 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %77 = stablehlo.divide %7, %74 : tensor<20x20xf32> - %78 = stablehlo.log_plus_one %77 : tensor<20x20xf32> -- %79 = stablehlo.add %76, %78 : tensor<20x20xf32> -+ %79 = stablehlo.add %78, %76 : tensor<20x20xf32> - %80 = stablehlo.divide %71, %73 : tensor<20x20xf32> - %81 = stablehlo.constant dense<7.000000e+00> : tensor<20x20xf32> - %82 = stablehlo.divide %81, %75 : tensor<20x20xf32> -@@ -94,10 +94,10 @@ - %88 = stablehlo.abs %87 : tensor<20x20xf32> - %89 = stablehlo.add %0, %88 : tensor<20x20xf32> - %90 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %91 = stablehlo.multiply %90, %89 : tensor<20x20xf32> -+ %91 = stablehlo.multiply %89, %90 : tensor<20x20xf32> - %92 = stablehlo.cosine %91 : tensor<20x20xf32> - %93 = stablehlo.sine %91 : tensor<20x20xf32> -- %94 = stablehlo.multiply %90, %92 : tensor<20x20xf32> -+ %94 = stablehlo.multiply %92, %90 : tensor<20x20xf32> - %95 = stablehlo.divide %94, %93 : tensor<20x20xf32> - %96 = stablehlo.subtract %84, %95 : tensor<20x20xf32> - %97 = stablehlo.select %3, %96, %84 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/erf_shape_bfloat16_20_20.mlir -@@ -11,7 +11,7 @@ - %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> - %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> - %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> -+ %8 = stablehlo.multiply %6, %7 : tensor<20x20xf32> - %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> - %10 = stablehlo.add %8, %9 : tensor<20x20xf32> - %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> -@@ -33,7 +33,7 @@ - %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> - %28 = stablehlo.add %26, %27 : tensor<20x20xf32> - %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> -+ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> - %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> - %32 = stablehlo.add %30, %31 : tensor<20x20xf32> - %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/erf_shape_float16_20_20.mlir -@@ -11,7 +11,7 @@ - %5 = stablehlo.clamp %3, %2, %4 : tensor<20x20xf32> - %6 = stablehlo.multiply %5, %5 : tensor<20x20xf32> - %7 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %8 = stablehlo.multiply %7, %6 : tensor<20x20xf32> -+ %8 = stablehlo.multiply 
%6, %7 : tensor<20x20xf32> - %9 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> - %10 = stablehlo.add %8, %9 : tensor<20x20xf32> - %11 = stablehlo.multiply %10, %6 : tensor<20x20xf32> -@@ -33,7 +33,7 @@ - %27 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> - %28 = stablehlo.add %26, %27 : tensor<20x20xf32> - %29 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %30 = stablehlo.multiply %29, %6 : tensor<20x20xf32> -+ %30 = stablehlo.multiply %6, %29 : tensor<20x20xf32> - %31 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> - %32 = stablehlo.add %30, %31 : tensor<20x20xf32> - %33 = stablehlo.multiply %32, %6 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/erf_shape_float32_20_20.mlir -@@ -10,7 +10,7 @@ - %4 = stablehlo.clamp %2, %0, %3 : tensor<20x20xf32> - %5 = stablehlo.multiply %4, %4 : tensor<20x20xf32> - %6 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %7 = stablehlo.multiply %6, %5 : tensor<20x20xf32> -+ %7 = stablehlo.multiply %5, %6 : tensor<20x20xf32> - %8 = stablehlo.constant dense<-2.72614237E-10> : tensor<20x20xf32> - %9 = stablehlo.add %7, %8 : tensor<20x20xf32> - %10 = stablehlo.multiply %9, %5 : tensor<20x20xf32> -@@ -32,7 +32,7 @@ - %26 = stablehlo.constant dense<-0.0160960332> : tensor<20x20xf32> - %27 = stablehlo.add %25, %26 : tensor<20x20xf32> - %28 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %29 = stablehlo.multiply %28, %5 : tensor<20x20xf32> -+ %29 = stablehlo.multiply %5, %28 : tensor<20x20xf32> - %30 = stablehlo.constant dense<-1.45660715E-5> : tensor<20x20xf32> - %31 = stablehlo.add %29, %30 : tensor<20x20xf32> - %32 = stablehlo.multiply %31, %5 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/erfc_shape_bfloat16_20_20.mlir -@@ -17,7 +17,7 @@ - %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> -+ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> - %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> - %16 = stablehlo.add %14, %15 : tensor<20x20xf32> - %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> -@@ -45,7 +45,7 @@ - %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> - %40 = stablehlo.add %38, %39 : tensor<20x20xf32> - %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> -+ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> - %43 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> - %44 = stablehlo.add %42, %43 : tensor<20x20xf32> - %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> -@@ -81,7 +81,7 @@ - %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> - %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> -+ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> - %79 = stablehlo.constant 
dense<7.85386146E-5> : tensor<20x20xf32> - %80 = stablehlo.add %78, %79 : tensor<20x20xf32> - %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/erfc_shape_float16_20_20.mlir -@@ -17,7 +17,7 @@ - %11 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %12 = stablehlo.compare LT, %5, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %13 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %14 = stablehlo.multiply %13, %7 : tensor<20x20xf32> -+ %14 = stablehlo.multiply %7, %13 : tensor<20x20xf32> - %15 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> - %16 = stablehlo.add %14, %15 : tensor<20x20xf32> - %17 = stablehlo.multiply %16, %7 : tensor<20x20xf32> -@@ -45,7 +45,7 @@ - %39 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> - %40 = stablehlo.add %38, %39 : tensor<20x20xf32> - %41 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %42 = stablehlo.multiply %41, %7 : tensor<20x20xf32> -+ %42 = stablehlo.multiply %7, %41 : tensor<20x20xf32> - %43 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> - %44 = stablehlo.add %42, %43 : tensor<20x20xf32> - %45 = stablehlo.multiply %44, %7 : tensor<20x20xf32> -@@ -81,7 +81,7 @@ - %75 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %76 = stablehlo.multiply %2, %2 : tensor<20x20xf32> - %77 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %78 = stablehlo.multiply %77, %76 : tensor<20x20xf32> -+ %78 = stablehlo.multiply %76, %77 : tensor<20x20xf32> - %79 = stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> - %80 = stablehlo.add %78, %79 : tensor<20x20xf32> - %81 = stablehlo.multiply %80, %76 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/erfc_shape_float32_20_20.mlir -@@ -16,7 +16,7 @@ - %10 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %11 = stablehlo.compare LT, %4, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %12 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %13 = stablehlo.multiply %12, %6 : tensor<20x20xf32> -+ %13 = stablehlo.multiply %6, %12 : tensor<20x20xf32> - %14 = stablehlo.constant dense<2.326820e-02> : tensor<20x20xf32> - %15 = stablehlo.add %13, %14 : tensor<20x20xf32> - %16 = stablehlo.multiply %15, %6 : tensor<20x20xf32> -@@ -44,7 +44,7 @@ - %38 = stablehlo.constant dense<0.563825965> : tensor<20x20xf32> - %39 = stablehlo.add %37, %38 : tensor<20x20xf32> - %40 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %41 = stablehlo.multiply %40, %6 : tensor<20x20xf32> -+ %41 = stablehlo.multiply %6, %40 : tensor<20x20xf32> - %42 = stablehlo.constant dense<-10.477664> : tensor<20x20xf32> - %43 = stablehlo.add %41, %42 : tensor<20x20xf32> - %44 = stablehlo.multiply %43, %6 : tensor<20x20xf32> -@@ -80,7 +80,7 @@ - %74 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %75 = stablehlo.multiply %0, %0 : tensor<20x20xf32> - %76 = stablehlo.constant dense<0.000000e+00> : tensor<20x20xf32> -- %77 = stablehlo.multiply %76, %75 : tensor<20x20xf32> -+ %77 = stablehlo.multiply %75, %76 : tensor<20x20xf32> - %78 = 
stablehlo.constant dense<7.85386146E-5> : tensor<20x20xf32> - %79 = stablehlo.add %77, %78 : tensor<20x20xf32> - %80 = stablehlo.multiply %79, %75 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bfloat16_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_bool_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff 
--ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_complex64_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float16_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir 
b/stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_float32_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int16_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir ---- 
stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int32_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_int8_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint16_10__axis_0_enable_xla_True.mlir -@@ 
-34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint32_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir b/stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir ---- stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir -+++ stablehlo/stablehlo/testdata/gather_dtypes_shape_uint8_10__axis_0_enable_xla_True.mlir -@@ -34,7 +34,7 @@ - %12 = stablehlo.compare LT, %6, %11, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %13 = stablehlo.constant dense<1> : tensor - %14 = 
stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xi32> -- %15 = stablehlo.add %6, %14 : tensor<1xi32> -+ %15 = stablehlo.add %14, %6 : tensor<1xi32> - %16 = stablehlo.select %12, %15, %6 : tensor<1xi1>, tensor<1xi32> - %17 = stablehlo.broadcast_in_dim %16, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %18 = "stablehlo.gather"(%9, %17) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -45,7 +45,7 @@ - %23 = stablehlo.compare LT, %7, %22, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %24 = stablehlo.constant dense<1> : tensor - %25 = stablehlo.broadcast_in_dim %24, dims = [] : (tensor) -> tensor<1xi32> -- %26 = stablehlo.add %7, %25 : tensor<1xi32> -+ %26 = stablehlo.add %25, %7 : tensor<1xi32> - %27 = stablehlo.select %23, %26, %7 : tensor<1xi1>, tensor<1xi32> - %28 = stablehlo.broadcast_in_dim %27, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %29 = "stablehlo.gather"(%20, %28) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill.mlir -@@ -39,7 +39,7 @@ - %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %18 = stablehlo.constant dense<3> : tensor - %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> -- %20 = stablehlo.add %6, %19 : tensor<1xi32> -+ %20 = stablehlo.add %19, %6 : tensor<1xi32> - %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> - %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -55,7 +55,7 @@ - %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %34 = stablehlo.constant dense<3> : tensor - %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> -- %36 = stablehlo.add %7, %35 : tensor<1xi32> -+ %36 = stablehlo.add %35, %7 : tensor<1xi32> - %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> - %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_1_enable_xla_True_mode_fill.mlir -@@ -39,7 +39,7 @@ - %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %18 = stablehlo.constant dense<3> : tensor - %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> 
-- %20 = stablehlo.add %6, %19 : tensor<1xi32> -+ %20 = stablehlo.add %19, %6 : tensor<1xi32> - %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> - %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -55,7 +55,7 @@ - %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %34 = stablehlo.constant dense<3> : tensor - %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> -- %36 = stablehlo.add %7, %35 : tensor<1xi32> -+ %36 = stablehlo.add %35, %7 : tensor<1xi32> - %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> - %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill.mlir -@@ -39,7 +39,7 @@ - %17 = stablehlo.compare LT, %6, %16, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %18 = stablehlo.constant dense<3> : tensor - %19 = stablehlo.broadcast_in_dim %18, dims = [] : (tensor) -> tensor<1xi32> -- %20 = stablehlo.add %6, %19 : tensor<1xi32> -+ %20 = stablehlo.add %19, %6 : tensor<1xi32> - %21 = stablehlo.select %17, %20, %6 : tensor<1xi1>, tensor<1xi32> - %22 = stablehlo.broadcast_in_dim %21, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %23 = "stablehlo.gather"(%14, %22) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -55,7 +55,7 @@ - %33 = stablehlo.compare LT, %7, %32, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %34 = stablehlo.constant dense<3> : tensor - %35 = stablehlo.broadcast_in_dim %34, dims = [] : (tensor) -> tensor<1xi32> -- %36 = stablehlo.add %7, %35 : tensor<1xi32> -+ %36 = stablehlo.add %35, %7 : tensor<1xi32> - %37 = stablehlo.select %33, %36, %7 : tensor<1xi1>, tensor<1xi32> - %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %39 = "stablehlo.gather"(%30, %38) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add 
%21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__2__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : 
tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, 
dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_0_enable_xla_True_mode_fill.mlir -@@ -42,7 +42,7 @@ - %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %21 = stablehlo.constant dense<3> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> -- %23 = stablehlo.add %8, %22 : tensor<1xi32> -+ %23 = stablehlo.add %22, %8 : tensor<1xi32> - %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> - %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> 
tensor<1x1xi32> - %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -58,7 +58,7 @@ - %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %37 = stablehlo.constant dense<3> : tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> -- %39 = stablehlo.add %9, %38 : tensor<1xi32> -+ %39 = stablehlo.add %38, %9 : tensor<1xi32> - %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> - %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_1_enable_xla_True_mode_fill.mlir -@@ -42,7 +42,7 @@ - %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %21 = stablehlo.constant dense<3> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> -- %23 = stablehlo.add %8, %22 : tensor<1xi32> -+ %23 = stablehlo.add %22, %8 : tensor<1xi32> - %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> - %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -58,7 +58,7 @@ - %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %37 = stablehlo.constant dense<3> : tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> -- %39 = stablehlo.add %9, %38 : tensor<1xi32> -+ %39 = stablehlo.add %38, %9 : tensor<1xi32> - %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> - %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__3_uint32__axis_2_enable_xla_True_mode_fill.mlir -@@ -42,7 +42,7 @@ - %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %21 = stablehlo.constant dense<3> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi32> -- %23 = stablehlo.add %8, %22 : tensor<1xi32> -+ %23 = stablehlo.add %22, %8 : tensor<1xi32> - %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi32> - %25 = stablehlo.broadcast_in_dim %24, dims = [0] : (tensor<1xi32>) -> 
tensor<1x1xi32> - %26 = "stablehlo.gather"(%16, %25) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -58,7 +58,7 @@ - %36 = stablehlo.compare LT, %9, %35, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %37 = stablehlo.constant dense<3> : tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<1xi32> -- %39 = stablehlo.add %9, %38 : tensor<1xi32> -+ %39 = stablehlo.add %38, %9 : tensor<1xi32> - %40 = stablehlo.select %36, %39, %9 : tensor<1xi1>, tensor<1xi32> - %41 = stablehlo.broadcast_in_dim %40, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %42 = "stablehlo.gather"(%33, %41) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) 
{dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__4__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = 
dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__5_oob__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : 
(tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> 
tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__6_neg__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 
= stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__7_neg__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, 
SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_0_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_1_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED 
: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir b/stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir ---- stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir -+++ stablehlo/stablehlo/testdata/gather_from_take_indices_name__8_neg_oob__axis_2_enable_xla_True_mode_fill.mlir -@@ -41,7 +41,7 @@ - %19 = stablehlo.compare LT, %8, %18, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %20 = stablehlo.constant dense<3> : tensor - %21 = stablehlo.broadcast_in_dim %20, dims = [] : (tensor) -> tensor<1xi32> -- %22 = stablehlo.add %8, %21 : tensor<1xi32> -+ %22 = stablehlo.add %21, %8 : tensor<1xi32> - %23 = stablehlo.select %19, %22, %8 : tensor<1xi1>, tensor<1xi32> - %24 = stablehlo.broadcast_in_dim %23, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %25 = "stablehlo.gather"(%16, %24) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -@@ -57,7 +57,7 @@ - %35 = stablehlo.compare LT, %9, %34, SIGNED : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> - %36 = stablehlo.constant dense<3> : tensor - %37 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor) -> tensor<1xi32> -- %38 = stablehlo.add %9, %37 : tensor<1xi32> -+ %38 = stablehlo.add %37, %9 : tensor<1xi32> - %39 = stablehlo.select %35, %38, %9 : tensor<1xi1>, tensor<1xi32> - %40 = stablehlo.broadcast_in_dim %39, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> - %41 = "stablehlo.gather"(%32, %40) {dimension_numbers = #stablehlo.gather, slice_sizes = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<1x1xi32>) -> tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir -@@ -202,7 +202,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -272,7 +272,7 @@ - %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> - %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> -- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> -+ %37 = stablehlo.multiply %36, %26 : 
tensor<20x20xf32> - %38 = stablehlo.sine %37 : tensor<20x20xf32> - %39 = stablehlo.log %38 : tensor<20x20xf32> - %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -290,17 +290,17 @@ - %52 = stablehlo.add %50, %51 : tensor<20x20xf32> - %53 = stablehlo.constant dense<7.500000e+00> : tensor - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> -+ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> - %56 = stablehlo.constant dense<2.01490307> : tensor - %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> - %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> -- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> - %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> - %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> - %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> -- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> -+ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> - %66 = stablehlo.constant dense<1.000000e+00> : tensor - %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %68 = stablehlo.constant dense<676.520386> : tensor -@@ -311,7 +311,7 @@ - %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %74 = stablehlo.add %72, %73 : tensor<20x20xf32> - %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> -- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> -+ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> - %77 = stablehlo.constant dense<-1259.13916> : tensor - %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %79 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -585,7 +585,7 @@ - %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> - %241 = stablehlo.constant dense<-1.000000e+00> : tensor - %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> -+ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> - %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> - %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> - %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ---- stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir -+++ stablehlo/stablehlo/testdata/igamma_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir -@@ -202,7 +202,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -272,7 +272,7 @@ - %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> - %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> -- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> -+ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> - %38 = stablehlo.sine 
%37 : tensor<20x20xf32> - %39 = stablehlo.log %38 : tensor<20x20xf32> - %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -290,17 +290,17 @@ - %52 = stablehlo.add %50, %51 : tensor<20x20xf32> - %53 = stablehlo.constant dense<7.500000e+00> : tensor - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> -+ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> - %56 = stablehlo.constant dense<2.01490307> : tensor - %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> - %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> -- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> - %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> - %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> - %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> -- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> -+ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> - %66 = stablehlo.constant dense<1.000000e+00> : tensor - %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %68 = stablehlo.constant dense<676.520386> : tensor -@@ -311,7 +311,7 @@ - %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %74 = stablehlo.add %72, %73 : tensor<20x20xf32> - %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> -- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> -+ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> - %77 = stablehlo.constant dense<-1259.13916> : tensor - %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %79 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -585,7 +585,7 @@ - %240 = stablehlo.multiply %iterArg_4, %239 : tensor<20x20xf32> - %241 = stablehlo.constant dense<-1.000000e+00> : tensor - %242 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %243 = stablehlo.multiply %242, %iterArg_1 : tensor<20x20xf32> -+ %243 = stablehlo.multiply %iterArg_1, %242 : tensor<20x20xf32> - %244 = stablehlo.multiply %243, %iterArg_3 : tensor<20x20xf32> - %245 = stablehlo.multiply %227, %227 : tensor<20x20xf32> - %246 = stablehlo.divide %244, %245 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir -@@ -202,7 +202,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -272,7 +272,7 @@ - %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> - %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> -- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> -+ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> - %38 = stablehlo.sine %37 : tensor<20x20xf32> - %39 = stablehlo.log %38 : 
tensor<20x20xf32> - %40 = stablehlo.is_finite %39 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -290,17 +290,17 @@ - %52 = stablehlo.add %50, %51 : tensor<20x20xf32> - %53 = stablehlo.constant dense<7.500000e+00> : tensor - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> -+ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> - %56 = stablehlo.constant dense<2.01490307> : tensor - %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> - %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> -- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> - %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> - %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> - %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> -- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> -+ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> - %66 = stablehlo.constant dense<1.000000e+00> : tensor - %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %68 = stablehlo.constant dense<676.520386> : tensor -@@ -311,7 +311,7 @@ - %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %74 = stablehlo.add %72, %73 : tensor<20x20xf32> - %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> -- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> -+ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> - %77 = stablehlo.constant dense<-1259.13916> : tensor - %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %79 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -585,7 +585,7 @@ - %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> - %242 = stablehlo.constant dense<-1.000000e+00> : tensor - %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> -+ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> - %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> - %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> - %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir -@@ -202,7 +202,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -272,7 +272,7 @@ - %34 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %35 = stablehlo.subtract %34, %29 : tensor<20x20xf32> - %36 = stablehlo.select %32, %35, %29 : tensor<20x20xi1>, tensor<20x20xf32> -- %37 = stablehlo.multiply %26, %36 : tensor<20x20xf32> -+ %37 = stablehlo.multiply %36, %26 : tensor<20x20xf32> - %38 = stablehlo.sine %37 : tensor<20x20xf32> - %39 = stablehlo.log %38 : tensor<20x20xf32> - %40 = stablehlo.is_finite %39 : 
(tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -290,17 +290,17 @@ - %52 = stablehlo.add %50, %51 : tensor<20x20xf32> - %53 = stablehlo.constant dense<7.500000e+00> : tensor - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %55 = stablehlo.add %54, %50 : tensor<20x20xf32> -+ %55 = stablehlo.add %50, %54 : tensor<20x20xf32> - %56 = stablehlo.constant dense<2.01490307> : tensor - %57 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %58 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %59 = stablehlo.divide %50, %58 : tensor<20x20xf32> - %60 = stablehlo.log_plus_one %59 : tensor<20x20xf32> -- %61 = stablehlo.add %57, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %57 : tensor<20x20xf32> - %62 = stablehlo.divide %55, %61 : tensor<20x20xf32> - %63 = stablehlo.subtract %52, %62 : tensor<20x20xf32> - %64 = stablehlo.multiply %63, %61 : tensor<20x20xf32> -- %65 = stablehlo.add %45, %64 : tensor<20x20xf32> -+ %65 = stablehlo.add %64, %45 : tensor<20x20xf32> - %66 = stablehlo.constant dense<1.000000e+00> : tensor - %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %68 = stablehlo.constant dense<676.520386> : tensor -@@ -311,7 +311,7 @@ - %73 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %74 = stablehlo.add %72, %73 : tensor<20x20xf32> - %75 = stablehlo.divide %69, %74 : tensor<20x20xf32> -- %76 = stablehlo.add %67, %75 : tensor<20x20xf32> -+ %76 = stablehlo.add %75, %67 : tensor<20x20xf32> - %77 = stablehlo.constant dense<-1259.13916> : tensor - %78 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %79 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -585,7 +585,7 @@ - %241 = stablehlo.multiply %iterArg_4, %240 : tensor<20x20xf32> - %242 = stablehlo.constant dense<-1.000000e+00> : tensor - %243 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %244 = stablehlo.multiply %243, %iterArg_1 : tensor<20x20xf32> -+ %244 = stablehlo.multiply %iterArg_1, %243 : tensor<20x20xf32> - %245 = stablehlo.multiply %244, %iterArg_3 : tensor<20x20xf32> - %246 = stablehlo.multiply %228, %228 : tensor<20x20xf32> - %247 = stablehlo.divide %245, %246 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/igamma_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir -@@ -202,7 +202,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -270,7 +270,7 @@ - %32 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %33 = stablehlo.subtract %32, %27 : tensor<20x20xf32> - %34 = stablehlo.select %30, %33, %27 : tensor<20x20xi1>, tensor<20x20xf32> -- %35 = stablehlo.multiply %24, %34 : tensor<20x20xf32> -+ %35 = stablehlo.multiply %34, %24 : tensor<20x20xf32> - %36 = stablehlo.sine %35 : tensor<20x20xf32> - %37 = stablehlo.log %36 : tensor<20x20xf32> - %38 = stablehlo.is_finite %37 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -288,17 +288,17 @@ 
- %50 = stablehlo.add %48, %49 : tensor<20x20xf32> - %51 = stablehlo.constant dense<7.500000e+00> : tensor - %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %53 = stablehlo.add %52, %48 : tensor<20x20xf32> -+ %53 = stablehlo.add %48, %52 : tensor<20x20xf32> - %54 = stablehlo.constant dense<2.01490307> : tensor - %55 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %56 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %57 = stablehlo.divide %48, %56 : tensor<20x20xf32> - %58 = stablehlo.log_plus_one %57 : tensor<20x20xf32> -- %59 = stablehlo.add %55, %58 : tensor<20x20xf32> -+ %59 = stablehlo.add %58, %55 : tensor<20x20xf32> - %60 = stablehlo.divide %53, %59 : tensor<20x20xf32> - %61 = stablehlo.subtract %50, %60 : tensor<20x20xf32> - %62 = stablehlo.multiply %61, %59 : tensor<20x20xf32> -- %63 = stablehlo.add %43, %62 : tensor<20x20xf32> -+ %63 = stablehlo.add %62, %43 : tensor<20x20xf32> - %64 = stablehlo.constant dense<1.000000e+00> : tensor - %65 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %66 = stablehlo.constant dense<676.520386> : tensor -@@ -309,7 +309,7 @@ - %71 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %72 = stablehlo.add %70, %71 : tensor<20x20xf32> - %73 = stablehlo.divide %67, %72 : tensor<20x20xf32> -- %74 = stablehlo.add %65, %73 : tensor<20x20xf32> -+ %74 = stablehlo.add %73, %65 : tensor<20x20xf32> - %75 = stablehlo.constant dense<-1259.13916> : tensor - %76 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %77 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -583,7 +583,7 @@ - %238 = stablehlo.multiply %iterArg_4, %237 : tensor<20x20xf32> - %239 = stablehlo.constant dense<-1.000000e+00> : tensor - %240 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %241 = stablehlo.multiply %240, %iterArg_1 : tensor<20x20xf32> -+ %241 = stablehlo.multiply %iterArg_1, %240 : tensor<20x20xf32> - %242 = stablehlo.multiply %241, %iterArg_3 : tensor<20x20xf32> - %243 = stablehlo.multiply %225, %225 : tensor<20x20xf32> - %244 = stablehlo.divide %242, %243 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir -@@ -47,7 +47,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -268,7 +268,7 @@ - %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> - %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> -- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> -+ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> - %34 = stablehlo.sine %33 : tensor<20x20xf32> - %35 = stablehlo.log %34 : tensor<20x20xf32> - %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -286,17 +286,17 @@ - %48 = stablehlo.add %46, %47 : 
tensor<20x20xf32> - %49 = stablehlo.constant dense<7.500000e+00> : tensor - %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> -+ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor - %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> - %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> -- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> -+ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> - %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> - %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> - %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> -- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> - %62 = stablehlo.constant dense<1.000000e+00> : tensor - %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %64 = stablehlo.constant dense<676.520386> : tensor -@@ -307,7 +307,7 @@ - %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %70 = stablehlo.add %68, %69 : tensor<20x20xf32> - %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> -- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> -+ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> - %73 = stablehlo.constant dense<-1259.13916> : tensor - %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %75 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -428,7 +428,7 @@ - %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> - %229 = stablehlo.constant dense<-1.000000e+00> : tensor - %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> -+ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> - %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> - %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> - %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir ---- stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir -+++ stablehlo/stablehlo/testdata/igammac_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir -@@ -47,7 +47,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -268,7 +268,7 @@ - %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> - %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> -- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> -+ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> - %34 = stablehlo.sine %33 : tensor<20x20xf32> - %35 = stablehlo.log %34 : tensor<20x20xf32> - %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -286,17 +286,17 @@ - %48 = stablehlo.add %46, %47 : tensor<20x20xf32> - %49 = 
stablehlo.constant dense<7.500000e+00> : tensor - %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> -+ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor - %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> - %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> -- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> -+ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> - %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> - %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> - %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> -- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> - %62 = stablehlo.constant dense<1.000000e+00> : tensor - %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %64 = stablehlo.constant dense<676.520386> : tensor -@@ -307,7 +307,7 @@ - %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %70 = stablehlo.add %68, %69 : tensor<20x20xf32> - %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> -- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> -+ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> - %73 = stablehlo.constant dense<-1259.13916> : tensor - %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %75 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -428,7 +428,7 @@ - %228 = stablehlo.multiply %iterArg_4, %227 : tensor<20x20xf32> - %229 = stablehlo.constant dense<-1.000000e+00> : tensor - %230 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %231 = stablehlo.multiply %230, %iterArg_1 : tensor<20x20xf32> -+ %231 = stablehlo.multiply %iterArg_1, %230 : tensor<20x20xf32> - %232 = stablehlo.multiply %231, %iterArg_3 : tensor<20x20xf32> - %233 = stablehlo.multiply %215, %215 : tensor<20x20xf32> - %234 = stablehlo.divide %232, %233 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir -@@ -47,7 +47,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -268,7 +268,7 @@ - %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> - %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> -- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> -+ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> - %34 = stablehlo.sine %33 : tensor<20x20xf32> - %35 = stablehlo.log %34 : tensor<20x20xf32> - %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -286,17 +286,17 @@ - %48 = stablehlo.add %46, %47 : tensor<20x20xf32> - %49 = stablehlo.constant dense<7.500000e+00> : tensor - 
%50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> -+ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor - %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> - %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> -- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> -+ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> - %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> - %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> - %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> -- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> - %62 = stablehlo.constant dense<1.000000e+00> : tensor - %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %64 = stablehlo.constant dense<676.520386> : tensor -@@ -307,7 +307,7 @@ - %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %70 = stablehlo.add %68, %69 : tensor<20x20xf32> - %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> -- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> -+ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> - %73 = stablehlo.constant dense<-1259.13916> : tensor - %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %75 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -428,7 +428,7 @@ - %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> - %230 = stablehlo.constant dense<-1.000000e+00> : tensor - %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> -+ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> - %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> - %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> - %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir -@@ -47,7 +47,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -268,7 +268,7 @@ - %30 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %31 = stablehlo.subtract %30, %25 : tensor<20x20xf32> - %32 = stablehlo.select %28, %31, %25 : tensor<20x20xi1>, tensor<20x20xf32> -- %33 = stablehlo.multiply %22, %32 : tensor<20x20xf32> -+ %33 = stablehlo.multiply %32, %22 : tensor<20x20xf32> - %34 = stablehlo.sine %33 : tensor<20x20xf32> - %35 = stablehlo.log %34 : tensor<20x20xf32> - %36 = stablehlo.is_finite %35 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -286,17 +286,17 @@ - %48 = stablehlo.add %46, %47 : tensor<20x20xf32> - %49 = stablehlo.constant dense<7.500000e+00> : tensor - %50 = stablehlo.constant dense<7.500000e+00> : 
tensor<20x20xf32> -- %51 = stablehlo.add %50, %46 : tensor<20x20xf32> -+ %51 = stablehlo.add %46, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor - %53 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %54 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %55 = stablehlo.divide %46, %54 : tensor<20x20xf32> - %56 = stablehlo.log_plus_one %55 : tensor<20x20xf32> -- %57 = stablehlo.add %53, %56 : tensor<20x20xf32> -+ %57 = stablehlo.add %56, %53 : tensor<20x20xf32> - %58 = stablehlo.divide %51, %57 : tensor<20x20xf32> - %59 = stablehlo.subtract %48, %58 : tensor<20x20xf32> - %60 = stablehlo.multiply %59, %57 : tensor<20x20xf32> -- %61 = stablehlo.add %41, %60 : tensor<20x20xf32> -+ %61 = stablehlo.add %60, %41 : tensor<20x20xf32> - %62 = stablehlo.constant dense<1.000000e+00> : tensor - %63 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %64 = stablehlo.constant dense<676.520386> : tensor -@@ -307,7 +307,7 @@ - %69 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %70 = stablehlo.add %68, %69 : tensor<20x20xf32> - %71 = stablehlo.divide %65, %70 : tensor<20x20xf32> -- %72 = stablehlo.add %63, %71 : tensor<20x20xf32> -+ %72 = stablehlo.add %71, %63 : tensor<20x20xf32> - %73 = stablehlo.constant dense<-1259.13916> : tensor - %74 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %75 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -428,7 +428,7 @@ - %229 = stablehlo.multiply %iterArg_4, %228 : tensor<20x20xf32> - %230 = stablehlo.constant dense<-1.000000e+00> : tensor - %231 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %232 = stablehlo.multiply %231, %iterArg_1 : tensor<20x20xf32> -+ %232 = stablehlo.multiply %iterArg_1, %231 : tensor<20x20xf32> - %233 = stablehlo.multiply %232, %iterArg_3 : tensor<20x20xf32> - %234 = stablehlo.multiply %216, %216 : tensor<20x20xf32> - %235 = stablehlo.divide %233, %234 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/igammac_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir -@@ -47,7 +47,7 @@ - %21 = stablehlo.multiply %19, %20 : tensor<20x20xf32> - %22 = stablehlo.constant dense<-1.000000e+00> : tensor - %23 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %24 = stablehlo.multiply %23, %1 : tensor<20x20xf32> -+ %24 = stablehlo.multiply %1, %23 : tensor<20x20xf32> - %25 = stablehlo.multiply %24, %2 : tensor<20x20xf32> - %26 = stablehlo.multiply %6, %6 : tensor<20x20xf32> - %27 = stablehlo.divide %25, %26 : tensor<20x20xf32> -@@ -266,7 +266,7 @@ - %28 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %29 = stablehlo.subtract %28, %23 : tensor<20x20xf32> - %30 = stablehlo.select %26, %29, %23 : tensor<20x20xi1>, tensor<20x20xf32> -- %31 = stablehlo.multiply %20, %30 : tensor<20x20xf32> -+ %31 = stablehlo.multiply %30, %20 : tensor<20x20xf32> - %32 = stablehlo.sine %31 : tensor<20x20xf32> - %33 = stablehlo.log %32 : tensor<20x20xf32> - %34 = stablehlo.is_finite %33 : (tensor<20x20xf32>) -> tensor<20x20xi1> -@@ -284,17 +284,17 @@ - %46 = stablehlo.add %44, %45 : tensor<20x20xf32> - %47 = stablehlo.constant dense<7.500000e+00> : tensor - %48 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %49 = stablehlo.add %48, %44 : 
tensor<20x20xf32> -+ %49 = stablehlo.add %44, %48 : tensor<20x20xf32> - %50 = stablehlo.constant dense<2.01490307> : tensor - %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %52 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> - %53 = stablehlo.divide %44, %52 : tensor<20x20xf32> - %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> -- %55 = stablehlo.add %51, %54 : tensor<20x20xf32> -+ %55 = stablehlo.add %54, %51 : tensor<20x20xf32> - %56 = stablehlo.divide %49, %55 : tensor<20x20xf32> - %57 = stablehlo.subtract %46, %56 : tensor<20x20xf32> - %58 = stablehlo.multiply %57, %55 : tensor<20x20xf32> -- %59 = stablehlo.add %39, %58 : tensor<20x20xf32> -+ %59 = stablehlo.add %58, %39 : tensor<20x20xf32> - %60 = stablehlo.constant dense<1.000000e+00> : tensor - %61 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %62 = stablehlo.constant dense<676.520386> : tensor -@@ -305,7 +305,7 @@ - %67 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %68 = stablehlo.add %66, %67 : tensor<20x20xf32> - %69 = stablehlo.divide %63, %68 : tensor<20x20xf32> -- %70 = stablehlo.add %61, %69 : tensor<20x20xf32> -+ %70 = stablehlo.add %69, %61 : tensor<20x20xf32> - %71 = stablehlo.constant dense<-1259.13916> : tensor - %72 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %73 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -426,7 +426,7 @@ - %226 = stablehlo.multiply %iterArg_4, %225 : tensor<20x20xf32> - %227 = stablehlo.constant dense<-1.000000e+00> : tensor - %228 = stablehlo.constant dense<-1.000000e+00> : tensor<20x20xf32> -- %229 = stablehlo.multiply %228, %iterArg_1 : tensor<20x20xf32> -+ %229 = stablehlo.multiply %iterArg_1, %228 : tensor<20x20xf32> - %230 = stablehlo.multiply %229, %iterArg_3 : tensor<20x20xf32> - %231 = stablehlo.multiply %213, %213 : tensor<20x20xf32> - %232 = stablehlo.divide %230, %231 : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir ---- stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir -+++ stablehlo/stablehlo/testdata/index_in_dim_0_dynamic.mlir -@@ -3,7 +3,7 @@ - module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { - %0 = stablehlo.constant dense<-1> : tensor -- %1 = stablehlo.add %0, %arg0 : tensor -+ %1 = stablehlo.add %arg0, %0 : tensor - %2 = stablehlo.convert %1 : (tensor) -> tensor - %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> - %4 = stablehlo.constant dense<0> : tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir ---- stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir -+++ stablehlo/stablehlo/testdata/index_in_dim_idx_neg_dynamic.mlir -@@ -3,7 +3,7 @@ - module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<4xf32> { - %0 = stablehlo.constant dense<-1> : tensor -- %1 = stablehlo.add %0, %arg0 : tensor -+ %1 = stablehlo.add %arg0, %0 : tensor - %2 = stablehlo.convert %1 : (tensor) -> tensor - %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> - %4 = stablehlo.constant dense<0> : tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir -+++ 
stablehlo/stablehlo/testdata/lgamma_shape_bfloat16_20_20.mlir -@@ -17,7 +17,7 @@ - %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %12 = stablehlo.add %8, %11 : tensor<20x20xf32> - %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> -- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> -+ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> - %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %17 = stablehlo.add %8, %16 : tensor<20x20xf32> -@@ -54,18 +54,18 @@ - %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> - %49 = stablehlo.add %44, %48 : tensor<20x20xf32> - %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> -+ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %53 = stablehlo.divide %8, %50 : tensor<20x20xf32> - %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> -- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> -+ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> - %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> - %57 = stablehlo.add %8, %3 : tensor<20x20xf32> - %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> - %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> - %60 = stablehlo.log %49 : tensor<20x20xf32> - %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> -- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> -+ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> - %63 = stablehlo.add %62, %60 : tensor<20x20xf32> - %64 = stablehlo.abs %2 : tensor<20x20xf32> - %65 = stablehlo.floor %64 : tensor<20x20xf32> -@@ -74,7 +74,7 @@ - %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> - %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> - %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> -+ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> - %72 = stablehlo.sine %71 : tensor<20x20xf32> - %73 = stablehlo.log %72 : tensor<20x20xf32> - %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/lgamma_shape_float16_20_20.mlir -@@ -17,7 +17,7 @@ - %11 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %12 = stablehlo.add %8, %11 : tensor<20x20xf32> - %13 = stablehlo.divide %10, %12 : tensor<20x20xf32> -- %14 = stablehlo.add %9, %13 : tensor<20x20xf32> -+ %14 = stablehlo.add %13, %9 : tensor<20x20xf32> - %15 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %16 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %17 = stablehlo.add %8, %16 : tensor<20x20xf32> -@@ -54,18 +54,18 @@ - %48 = stablehlo.divide %45, %47 : tensor<20x20xf32> - %49 = stablehlo.add %44, %48 : tensor<20x20xf32> - %50 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %51 = stablehlo.add %50, %8 : tensor<20x20xf32> -+ %51 = stablehlo.add %8, %50 : tensor<20x20xf32> - %52 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %53 = stablehlo.divide %8, %50 : tensor<20x20xf32> - %54 = stablehlo.log_plus_one %53 : tensor<20x20xf32> -- %55 = stablehlo.add %52, %54 : tensor<20x20xf32> -+ %55 = stablehlo.add %54, %52 : tensor<20x20xf32> - %56 = stablehlo.divide %51, %55 : tensor<20x20xf32> - %57 = 
stablehlo.add %8, %3 : tensor<20x20xf32> - %58 = stablehlo.subtract %57, %56 : tensor<20x20xf32> - %59 = stablehlo.multiply %58, %55 : tensor<20x20xf32> - %60 = stablehlo.log %49 : tensor<20x20xf32> - %61 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> -- %62 = stablehlo.add %61, %59 : tensor<20x20xf32> -+ %62 = stablehlo.add %59, %61 : tensor<20x20xf32> - %63 = stablehlo.add %62, %60 : tensor<20x20xf32> - %64 = stablehlo.abs %2 : tensor<20x20xf32> - %65 = stablehlo.floor %64 : tensor<20x20xf32> -@@ -74,7 +74,7 @@ - %68 = stablehlo.subtract %6, %66 : tensor<20x20xf32> - %69 = stablehlo.select %67, %68, %66 : tensor<20x20xi1>, tensor<20x20xf32> - %70 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %71 = stablehlo.multiply %70, %69 : tensor<20x20xf32> -+ %71 = stablehlo.multiply %69, %70 : tensor<20x20xf32> - %72 = stablehlo.sine %71 : tensor<20x20xf32> - %73 = stablehlo.log %72 : tensor<20x20xf32> - %74 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/lgamma_shape_float32_20_20.mlir -@@ -16,7 +16,7 @@ - %10 = stablehlo.constant dense<1.000000e+00> : tensor<20x20xf32> - %11 = stablehlo.add %7, %10 : tensor<20x20xf32> - %12 = stablehlo.divide %9, %11 : tensor<20x20xf32> -- %13 = stablehlo.add %8, %12 : tensor<20x20xf32> -+ %13 = stablehlo.add %12, %8 : tensor<20x20xf32> - %14 = stablehlo.constant dense<-1259.13916> : tensor<20x20xf32> - %15 = stablehlo.constant dense<2.000000e+00> : tensor<20x20xf32> - %16 = stablehlo.add %7, %15 : tensor<20x20xf32> -@@ -53,18 +53,18 @@ - %47 = stablehlo.divide %44, %46 : tensor<20x20xf32> - %48 = stablehlo.add %43, %47 : tensor<20x20xf32> - %49 = stablehlo.constant dense<7.500000e+00> : tensor<20x20xf32> -- %50 = stablehlo.add %49, %7 : tensor<20x20xf32> -+ %50 = stablehlo.add %7, %49 : tensor<20x20xf32> - %51 = stablehlo.constant dense<2.01490307> : tensor<20x20xf32> - %52 = stablehlo.divide %7, %49 : tensor<20x20xf32> - %53 = stablehlo.log_plus_one %52 : tensor<20x20xf32> -- %54 = stablehlo.add %51, %53 : tensor<20x20xf32> -+ %54 = stablehlo.add %53, %51 : tensor<20x20xf32> - %55 = stablehlo.divide %50, %54 : tensor<20x20xf32> - %56 = stablehlo.add %7, %2 : tensor<20x20xf32> - %57 = stablehlo.subtract %56, %55 : tensor<20x20xf32> - %58 = stablehlo.multiply %57, %54 : tensor<20x20xf32> - %59 = stablehlo.log %48 : tensor<20x20xf32> - %60 = stablehlo.constant dense<0.918938517> : tensor<20x20xf32> -- %61 = stablehlo.add %60, %58 : tensor<20x20xf32> -+ %61 = stablehlo.add %58, %60 : tensor<20x20xf32> - %62 = stablehlo.add %61, %59 : tensor<20x20xf32> - %63 = stablehlo.abs %0 : tensor<20x20xf32> - %64 = stablehlo.floor %63 : tensor<20x20xf32> -@@ -73,7 +73,7 @@ - %67 = stablehlo.subtract %5, %65 : tensor<20x20xf32> - %68 = stablehlo.select %66, %67, %65 : tensor<20x20xi1>, tensor<20x20xf32> - %69 = stablehlo.constant dense<3.14159274> : tensor<20x20xf32> -- %70 = stablehlo.multiply %69, %68 : tensor<20x20xf32> -+ %70 = stablehlo.multiply %68, %69 : tensor<20x20xf32> - %71 = stablehlo.sine %70 : tensor<20x20xf32> - %72 = stablehlo.log %71 : tensor<20x20xf32> - %73 = stablehlo.constant dense<1.14472985> : tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir b/stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir ---- 
stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir -+++ stablehlo/stablehlo/testdata/nanquantile_axis_None_dynamic.mlir -@@ -72,12 +72,12 @@ - %24 = stablehlo.subtract %14, %23 : tensor - %25 = stablehlo.minimum %18, %24 : tensor - %26 = stablehlo.constant dense<0.000000e+00> : tensor -- %27 = stablehlo.maximum %26, %25 : tensor -+ %27 = stablehlo.maximum %25, %26 : tensor - %28 = stablehlo.constant dense<1.000000e+00> : tensor - %29 = stablehlo.subtract %14, %28 : tensor - %30 = stablehlo.minimum %19, %29 : tensor - %31 = stablehlo.constant dense<0.000000e+00> : tensor -- %32 = stablehlo.maximum %31, %30 : tensor -+ %32 = stablehlo.maximum %30, %31 : tensor - %33 = stablehlo.convert %27 : (tensor) -> tensor - %34 = stablehlo.convert %32 : (tensor) -> tensor - %35 = stablehlo.constant dense<5> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir ---- stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir -+++ stablehlo/stablehlo/testdata/random_gamma_shape_float32.mlir -@@ -338,7 +338,7 @@ - %122 = stablehlo.add %120, %121 : tensor - %123 = stablehlo.reshape %122 : (tensor) -> tensor - %124 = stablehlo.constant dense<0.000000e+00> : tensor -- %125 = stablehlo.maximum %124, %123 : tensor -+ %125 = stablehlo.maximum %123, %124 : tensor - %126 = stablehlo.constant dense<0.000000e+00> : tensor - %127 = stablehlo.constant dense<1.000000e+00> : tensor - %128 = stablehlo.constant dense<2.000000e+00> : tensor -@@ -346,7 +346,7 @@ - cond { - %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor - %152 = stablehlo.constant dense<3.310000e-02> : tensor -- %153 = stablehlo.multiply %152, %151 : tensor -+ %153 = stablehlo.multiply %151, %152 : tensor - %154 = stablehlo.constant dense<1.000000e+00> : tensor - %155 = stablehlo.subtract %154, %153 : tensor - %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor -@@ -656,13 +656,13 @@ - %289 = stablehlo.add %287, %288 : tensor - %290 = stablehlo.reshape %289 : (tensor) -> tensor - %291 = stablehlo.constant dense<-0.99999994> : tensor -- %292 = stablehlo.maximum %291, %290 : tensor -+ %292 = stablehlo.maximum %290, %291 : tensor - %293 = func.call @erf_inv(%292) : (tensor) -> tensor - %294 = stablehlo.constant dense<1.41421354> : tensor -- %295 = stablehlo.multiply %294, %293 : tensor -+ %295 = stablehlo.multiply %293, %294 : tensor - %296 = stablehlo.multiply %295, %iterArg_9 : tensor - %297 = stablehlo.constant dense<1.000000e+00> : tensor -- %298 = stablehlo.add %297, %296 : tensor -+ %298 = stablehlo.add %296, %297 : tensor - stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor - } - %181 = stablehlo.multiply %180#2, %180#2 : tensor -@@ -773,7 +773,7 @@ - %222 = stablehlo.add %220, %221 : tensor - %223 = stablehlo.reshape %222 : (tensor) -> tensor - %224 = stablehlo.constant dense<0.000000e+00> : tensor -- %225 = stablehlo.maximum %224, %223 : tensor -+ %225 = stablehlo.maximum %223, %224 : tensor - stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor - } - %130 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir ---- stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir -+++ stablehlo/stablehlo/testdata/random_gamma_shape_float32_3.mlir -@@ -341,7 +341,7 @@ - %122 
= stablehlo.add %120, %121 : tensor - %123 = stablehlo.reshape %122 : (tensor) -> tensor - %124 = stablehlo.constant dense<0.000000e+00> : tensor -- %125 = stablehlo.maximum %124, %123 : tensor -+ %125 = stablehlo.maximum %123, %124 : tensor - %126 = stablehlo.constant dense<0.000000e+00> : tensor - %127 = stablehlo.constant dense<1.000000e+00> : tensor - %128 = stablehlo.constant dense<2.000000e+00> : tensor -@@ -349,7 +349,7 @@ - cond { - %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor - %152 = stablehlo.constant dense<3.310000e-02> : tensor -- %153 = stablehlo.multiply %152, %151 : tensor -+ %153 = stablehlo.multiply %151, %152 : tensor - %154 = stablehlo.constant dense<1.000000e+00> : tensor - %155 = stablehlo.subtract %154, %153 : tensor - %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor -@@ -659,13 +659,13 @@ - %289 = stablehlo.add %287, %288 : tensor - %290 = stablehlo.reshape %289 : (tensor) -> tensor - %291 = stablehlo.constant dense<-0.99999994> : tensor -- %292 = stablehlo.maximum %291, %290 : tensor -+ %292 = stablehlo.maximum %290, %291 : tensor - %293 = func.call @erf_inv(%292) : (tensor) -> tensor - %294 = stablehlo.constant dense<1.41421354> : tensor -- %295 = stablehlo.multiply %294, %293 : tensor -+ %295 = stablehlo.multiply %293, %294 : tensor - %296 = stablehlo.multiply %295, %iterArg_9 : tensor - %297 = stablehlo.constant dense<1.000000e+00> : tensor -- %298 = stablehlo.add %297, %296 : tensor -+ %298 = stablehlo.add %296, %297 : tensor - stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor - } - %181 = stablehlo.multiply %180#2, %180#2 : tensor -@@ -776,7 +776,7 @@ - %222 = stablehlo.add %220, %221 : tensor - %223 = stablehlo.reshape %222 : (tensor) -> tensor - %224 = stablehlo.constant dense<0.000000e+00> : tensor -- %225 = stablehlo.maximum %224, %223 : tensor -+ %225 = stablehlo.maximum %223, %224 : tensor - stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor - } - %130 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir ---- stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir -+++ stablehlo/stablehlo/testdata/random_gamma_shape_float64.mlir -@@ -338,7 +338,7 @@ - %122 = stablehlo.add %120, %121 : tensor - %123 = stablehlo.reshape %122 : (tensor) -> tensor - %124 = stablehlo.constant dense<0.000000e+00> : tensor -- %125 = stablehlo.maximum %124, %123 : tensor -+ %125 = stablehlo.maximum %123, %124 : tensor - %126 = stablehlo.constant dense<0.000000e+00> : tensor - %127 = stablehlo.constant dense<1.000000e+00> : tensor - %128 = stablehlo.constant dense<2.000000e+00> : tensor -@@ -346,7 +346,7 @@ - cond { - %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor - %152 = stablehlo.constant dense<3.310000e-02> : tensor -- %153 = stablehlo.multiply %152, %151 : tensor -+ %153 = stablehlo.multiply %151, %152 : tensor - %154 = stablehlo.constant dense<1.000000e+00> : tensor - %155 = stablehlo.subtract %154, %153 : tensor - %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor -@@ -656,13 +656,13 @@ - %289 = stablehlo.add %287, %288 : tensor - %290 = stablehlo.reshape %289 : (tensor) -> tensor - %291 = stablehlo.constant dense<-0.99999994> : tensor -- %292 = stablehlo.maximum %291, %290 : tensor -+ %292 = stablehlo.maximum %290, %291 : tensor - %293 = 
func.call @erf_inv(%292) : (tensor) -> tensor - %294 = stablehlo.constant dense<1.41421354> : tensor -- %295 = stablehlo.multiply %294, %293 : tensor -+ %295 = stablehlo.multiply %293, %294 : tensor - %296 = stablehlo.multiply %295, %iterArg_9 : tensor - %297 = stablehlo.constant dense<1.000000e+00> : tensor -- %298 = stablehlo.add %297, %296 : tensor -+ %298 = stablehlo.add %296, %297 : tensor - stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor - } - %181 = stablehlo.multiply %180#2, %180#2 : tensor -@@ -773,7 +773,7 @@ - %222 = stablehlo.add %220, %221 : tensor - %223 = stablehlo.reshape %222 : (tensor) -> tensor - %224 = stablehlo.constant dense<0.000000e+00> : tensor -- %225 = stablehlo.maximum %224, %223 : tensor -+ %225 = stablehlo.maximum %223, %224 : tensor - stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor - } - %130 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir b/stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir ---- stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir -+++ stablehlo/stablehlo/testdata/random_gamma_shape_float64_3.mlir -@@ -341,7 +341,7 @@ - %122 = stablehlo.add %120, %121 : tensor - %123 = stablehlo.reshape %122 : (tensor) -> tensor - %124 = stablehlo.constant dense<0.000000e+00> : tensor -- %125 = stablehlo.maximum %124, %123 : tensor -+ %125 = stablehlo.maximum %123, %124 : tensor - %126 = stablehlo.constant dense<0.000000e+00> : tensor - %127 = stablehlo.constant dense<1.000000e+00> : tensor - %128 = stablehlo.constant dense<2.000000e+00> : tensor -@@ -349,7 +349,7 @@ - cond { - %151 = stablehlo.multiply %iterArg_6, %iterArg_6 : tensor - %152 = stablehlo.constant dense<3.310000e-02> : tensor -- %153 = stablehlo.multiply %152, %151 : tensor -+ %153 = stablehlo.multiply %151, %152 : tensor - %154 = stablehlo.constant dense<1.000000e+00> : tensor - %155 = stablehlo.subtract %154, %153 : tensor - %156 = stablehlo.compare GE, %iterArg_8, %155, FLOAT : (tensor, tensor) -> tensor -@@ -659,13 +659,13 @@ - %289 = stablehlo.add %287, %288 : tensor - %290 = stablehlo.reshape %289 : (tensor) -> tensor - %291 = stablehlo.constant dense<-0.99999994> : tensor -- %292 = stablehlo.maximum %291, %290 : tensor -+ %292 = stablehlo.maximum %290, %291 : tensor - %293 = func.call @erf_inv(%292) : (tensor) -> tensor - %294 = stablehlo.constant dense<1.41421354> : tensor -- %295 = stablehlo.multiply %294, %293 : tensor -+ %295 = stablehlo.multiply %293, %294 : tensor - %296 = stablehlo.multiply %295, %iterArg_9 : tensor - %297 = stablehlo.constant dense<1.000000e+00> : tensor -- %298 = stablehlo.add %297, %296 : tensor -+ %298 = stablehlo.add %296, %297 : tensor - stablehlo.return %iterArg_9, %248, %295, %298 : tensor, tensor<2xui32>, tensor, tensor - } - %181 = stablehlo.multiply %180#2, %180#2 : tensor -@@ -776,7 +776,7 @@ - %222 = stablehlo.add %220, %221 : tensor - %223 = stablehlo.reshape %222 : (tensor) -> tensor - %224 = stablehlo.constant dense<0.000000e+00> : tensor -- %225 = stablehlo.maximum %224, %223 : tensor -+ %225 = stablehlo.maximum %223, %224 : tensor - stablehlo.return %iterArg_3, %iterArg_4, %173, %181, %183, %225 : tensor, tensor, tensor<2xui32>, tensor, tensor, tensor - } - %130 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir 
b/stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir ---- stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir -+++ stablehlo/stablehlo/testdata/random_uniform_shape_bfloat16.mlir -@@ -125,7 +125,7 @@ - %55 = stablehlo.add %53, %54 : tensor - %56 = stablehlo.reshape %55 : (tensor) -> tensor - %57 = stablehlo.constant dense<0.000000e+00> : tensor -- %58 = stablehlo.maximum %57, %56 : tensor -+ %58 = stablehlo.maximum %56, %57 : tensor - %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor - return %59 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir ---- stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir -+++ stablehlo/stablehlo/testdata/random_uniform_shape_float16.mlir -@@ -125,7 +125,7 @@ - %55 = stablehlo.add %53, %54 : tensor - %56 = stablehlo.reshape %55 : (tensor) -> tensor - %57 = stablehlo.constant dense<0.000000e+00> : tensor -- %58 = stablehlo.maximum %57, %56 : tensor -+ %58 = stablehlo.maximum %56, %57 : tensor - %59 = stablehlo.custom_call @check.eq(%58, %1) : (tensor, tensor) -> tensor - return %59 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir b/stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir ---- stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir -+++ stablehlo/stablehlo/testdata/random_uniform_shape_float32.mlir -@@ -110,7 +110,7 @@ - %40 = stablehlo.add %38, %39 : tensor - %41 = stablehlo.reshape %40 : (tensor) -> tensor - %42 = stablehlo.constant dense<0.000000e+00> : tensor -- %43 = stablehlo.maximum %42, %41 : tensor -+ %43 = stablehlo.maximum %41, %42 : tensor - %44 = stablehlo.custom_call @check.eq(%43, %1) : (tensor, tensor) -> tensor - return %44 : tensor - } -diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir ---- stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir -+++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__bfloat16.mlir -@@ -71,9 +71,9 @@ - %40 = stablehlo.multiply %38, %39 : tensor<9xf32> - %41 = stablehlo.constant dense<2.000000e+00> : tensor - %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> - %44 = stablehlo.add %25, %43 : tensor<9xf32> -- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> - %46 = stablehlo.add %25, %45 : tensor<9xf32> - %47 = stablehlo.constant dense<1.000000e+00> : tensor - %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -83,10 +83,10 @@ - %52 = stablehlo.subtract %35, %32 : tensor<9xf32> - %53 = stablehlo.multiply %32, %52 : tensor<9xf32> - %54 = stablehlo.multiply %53, %39 : tensor<9xf32> -- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> - %56 = stablehlo.add %25, %55 : tensor<9xf32> - %57 = stablehlo.subtract %56, %48 : tensor<9xf32> -- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> - %59 = stablehlo.add %25, %58 : tensor<9xf32> - %60 = stablehlo.multiply %57, %59 : tensor<9xf32> - %61 = stablehlo.divide %54, %60 : tensor<9xf32> -@@ -225,9 +225,9 @@ - %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> - %503 = stablehlo.constant dense<2.000000e+00> : tensor - 
%504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> - %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> -- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> - %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> - %509 = stablehlo.constant dense<1.000000e+00> : tensor - %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -237,10 +237,10 @@ - %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> - %515 = stablehlo.multiply %496, %514 : tensor<9xf32> - %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> -- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> - %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> - %519 = stablehlo.subtract %518, %510 : tensor<9xf32> -- %520 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> - %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> - %522 = stablehlo.multiply %519, %521 : tensor<9xf32> - %523 = stablehlo.divide %516, %522 : tensor<9xf32> -@@ -322,7 +322,7 @@ - %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %80 = stablehlo.subtract %79, %74 : tensor<9xf32> - %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> -- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> -+ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> - %83 = stablehlo.sine %82 : tensor<9xf32> - %84 = stablehlo.log %83 : tensor<9xf32> - %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> -@@ -340,17 +340,17 @@ - %97 = stablehlo.add %95, %96 : tensor<9xf32> - %98 = stablehlo.constant dense<7.500000e+00> : tensor - %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %100 = stablehlo.add %99, %95 : tensor<9xf32> -+ %100 = stablehlo.add %95, %99 : tensor<9xf32> - %101 = stablehlo.constant dense<2.01490307> : tensor - %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %104 = stablehlo.divide %95, %103 : tensor<9xf32> - %105 = stablehlo.log_plus_one %104 : tensor<9xf32> -- %106 = stablehlo.add %102, %105 : tensor<9xf32> -+ %106 = stablehlo.add %105, %102 : tensor<9xf32> - %107 = stablehlo.divide %100, %106 : tensor<9xf32> - %108 = stablehlo.subtract %97, %107 : tensor<9xf32> - %109 = stablehlo.multiply %108, %106 : tensor<9xf32> -- %110 = stablehlo.add %90, %109 : tensor<9xf32> -+ %110 = stablehlo.add %109, %90 : tensor<9xf32> - %111 = stablehlo.constant dense<1.000000e+00> : tensor - %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %113 = stablehlo.constant dense<676.520386> : tensor -@@ -361,7 +361,7 @@ - %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %119 = stablehlo.add %117, %118 : tensor<9xf32> - %120 = stablehlo.divide %114, %119 : tensor<9xf32> -- %121 = stablehlo.add %112, %120 : tensor<9xf32> -+ %121 = stablehlo.add %120, %112 : tensor<9xf32> - %122 = stablehlo.constant dense<-1259.13916> : tensor - %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %124 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -453,7 +453,7 @@ - %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %211 = stablehlo.subtract %210, %205 : tensor<9xf32> - %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> -- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> -+ %213 = 
stablehlo.multiply %212, %202 : tensor<9xf32> - %214 = stablehlo.sine %213 : tensor<9xf32> - %215 = stablehlo.log %214 : tensor<9xf32> - %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> -@@ -471,17 +471,17 @@ - %228 = stablehlo.add %226, %227 : tensor<9xf32> - %229 = stablehlo.constant dense<7.500000e+00> : tensor - %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %231 = stablehlo.add %230, %226 : tensor<9xf32> -+ %231 = stablehlo.add %226, %230 : tensor<9xf32> - %232 = stablehlo.constant dense<2.01490307> : tensor - %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %235 = stablehlo.divide %226, %234 : tensor<9xf32> - %236 = stablehlo.log_plus_one %235 : tensor<9xf32> -- %237 = stablehlo.add %233, %236 : tensor<9xf32> -+ %237 = stablehlo.add %236, %233 : tensor<9xf32> - %238 = stablehlo.divide %231, %237 : tensor<9xf32> - %239 = stablehlo.subtract %228, %238 : tensor<9xf32> - %240 = stablehlo.multiply %239, %237 : tensor<9xf32> -- %241 = stablehlo.add %221, %240 : tensor<9xf32> -+ %241 = stablehlo.add %240, %221 : tensor<9xf32> - %242 = stablehlo.constant dense<1.000000e+00> : tensor - %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %244 = stablehlo.constant dense<676.520386> : tensor -@@ -492,7 +492,7 @@ - %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %250 = stablehlo.add %248, %249 : tensor<9xf32> - %251 = stablehlo.divide %245, %250 : tensor<9xf32> -- %252 = stablehlo.add %243, %251 : tensor<9xf32> -+ %252 = stablehlo.add %251, %243 : tensor<9xf32> - %253 = stablehlo.constant dense<-1259.13916> : tensor - %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %255 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -586,7 +586,7 @@ - %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %344 = stablehlo.subtract %343, %338 : tensor<9xf32> - %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> -- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> -+ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> - %347 = stablehlo.sine %346 : tensor<9xf32> - %348 = stablehlo.log %347 : tensor<9xf32> - %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> -@@ -604,17 +604,17 @@ - %361 = stablehlo.add %359, %360 : tensor<9xf32> - %362 = stablehlo.constant dense<7.500000e+00> : tensor - %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %364 = stablehlo.add %363, %359 : tensor<9xf32> -+ %364 = stablehlo.add %359, %363 : tensor<9xf32> - %365 = stablehlo.constant dense<2.01490307> : tensor - %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %368 = stablehlo.divide %359, %367 : tensor<9xf32> - %369 = stablehlo.log_plus_one %368 : tensor<9xf32> -- %370 = stablehlo.add %366, %369 : tensor<9xf32> -+ %370 = stablehlo.add %369, %366 : tensor<9xf32> - %371 = stablehlo.divide %364, %370 : tensor<9xf32> - %372 = stablehlo.subtract %361, %371 : tensor<9xf32> - %373 = stablehlo.multiply %372, %370 : tensor<9xf32> -- %374 = stablehlo.add %354, %373 : tensor<9xf32> -+ %374 = stablehlo.add %373, %354 : tensor<9xf32> - %375 = stablehlo.constant dense<1.000000e+00> : tensor - %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %377 = stablehlo.constant dense<676.520386> : tensor -@@ -625,7 +625,7 @@ - %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %383 = stablehlo.add %381, %382 : tensor<9xf32> - %384 = 
stablehlo.divide %378, %383 : tensor<9xf32> -- %385 = stablehlo.add %376, %384 : tensor<9xf32> -+ %385 = stablehlo.add %384, %376 : tensor<9xf32> - %386 = stablehlo.constant dense<-1259.13916> : tensor - %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %388 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir ---- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir -+++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float16.mlir -@@ -71,9 +71,9 @@ - %40 = stablehlo.multiply %38, %39 : tensor<9xf32> - %41 = stablehlo.constant dense<2.000000e+00> : tensor - %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> - %44 = stablehlo.add %25, %43 : tensor<9xf32> -- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> - %46 = stablehlo.add %25, %45 : tensor<9xf32> - %47 = stablehlo.constant dense<1.000000e+00> : tensor - %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -83,10 +83,10 @@ - %52 = stablehlo.subtract %35, %32 : tensor<9xf32> - %53 = stablehlo.multiply %32, %52 : tensor<9xf32> - %54 = stablehlo.multiply %53, %39 : tensor<9xf32> -- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> - %56 = stablehlo.add %25, %55 : tensor<9xf32> - %57 = stablehlo.subtract %56, %48 : tensor<9xf32> -- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> - %59 = stablehlo.add %25, %58 : tensor<9xf32> - %60 = stablehlo.multiply %57, %59 : tensor<9xf32> - %61 = stablehlo.divide %54, %60 : tensor<9xf32> -@@ -225,9 +225,9 @@ - %502 = stablehlo.multiply %501, %iterArg_6 : tensor<9xf32> - %503 = stablehlo.constant dense<2.000000e+00> : tensor - %504 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %505 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %505 = stablehlo.multiply %496, %504 : tensor<9xf32> - %506 = stablehlo.add %iterArg_4, %505 : tensor<9xf32> -- %507 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %507 = stablehlo.multiply %496, %504 : tensor<9xf32> - %508 = stablehlo.add %iterArg_4, %507 : tensor<9xf32> - %509 = stablehlo.constant dense<1.000000e+00> : tensor - %510 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -237,10 +237,10 @@ - %514 = stablehlo.subtract %iterArg_5, %496 : tensor<9xf32> - %515 = stablehlo.multiply %496, %514 : tensor<9xf32> - %516 = stablehlo.multiply %515, %iterArg_6 : tensor<9xf32> -- %517 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %517 = stablehlo.multiply %496, %504 : tensor<9xf32> - %518 = stablehlo.add %iterArg_4, %517 : tensor<9xf32> - %519 = stablehlo.subtract %518, %510 : tensor<9xf32> -- %520 = stablehlo.multiply %504, %496 : tensor<9xf32> -+ %520 = stablehlo.multiply %496, %504 : tensor<9xf32> - %521 = stablehlo.add %iterArg_4, %520 : tensor<9xf32> - %522 = stablehlo.multiply %519, %521 : tensor<9xf32> - %523 = stablehlo.divide %516, %522 : tensor<9xf32> -@@ -322,7 +322,7 @@ - %79 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %80 = stablehlo.subtract %79, %74 : tensor<9xf32> - %81 = stablehlo.select %77, %80, %74 : tensor<9xi1>, tensor<9xf32> -- %82 = stablehlo.multiply %71, %81 : tensor<9xf32> -+ %82 = stablehlo.multiply %81, %71 : tensor<9xf32> 
- %83 = stablehlo.sine %82 : tensor<9xf32> - %84 = stablehlo.log %83 : tensor<9xf32> - %85 = stablehlo.is_finite %84 : (tensor<9xf32>) -> tensor<9xi1> -@@ -340,17 +340,17 @@ - %97 = stablehlo.add %95, %96 : tensor<9xf32> - %98 = stablehlo.constant dense<7.500000e+00> : tensor - %99 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %100 = stablehlo.add %99, %95 : tensor<9xf32> -+ %100 = stablehlo.add %95, %99 : tensor<9xf32> - %101 = stablehlo.constant dense<2.01490307> : tensor - %102 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %103 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %104 = stablehlo.divide %95, %103 : tensor<9xf32> - %105 = stablehlo.log_plus_one %104 : tensor<9xf32> -- %106 = stablehlo.add %102, %105 : tensor<9xf32> -+ %106 = stablehlo.add %105, %102 : tensor<9xf32> - %107 = stablehlo.divide %100, %106 : tensor<9xf32> - %108 = stablehlo.subtract %97, %107 : tensor<9xf32> - %109 = stablehlo.multiply %108, %106 : tensor<9xf32> -- %110 = stablehlo.add %90, %109 : tensor<9xf32> -+ %110 = stablehlo.add %109, %90 : tensor<9xf32> - %111 = stablehlo.constant dense<1.000000e+00> : tensor - %112 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %113 = stablehlo.constant dense<676.520386> : tensor -@@ -361,7 +361,7 @@ - %118 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %119 = stablehlo.add %117, %118 : tensor<9xf32> - %120 = stablehlo.divide %114, %119 : tensor<9xf32> -- %121 = stablehlo.add %112, %120 : tensor<9xf32> -+ %121 = stablehlo.add %120, %112 : tensor<9xf32> - %122 = stablehlo.constant dense<-1259.13916> : tensor - %123 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %124 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -453,7 +453,7 @@ - %210 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %211 = stablehlo.subtract %210, %205 : tensor<9xf32> - %212 = stablehlo.select %208, %211, %205 : tensor<9xi1>, tensor<9xf32> -- %213 = stablehlo.multiply %202, %212 : tensor<9xf32> -+ %213 = stablehlo.multiply %212, %202 : tensor<9xf32> - %214 = stablehlo.sine %213 : tensor<9xf32> - %215 = stablehlo.log %214 : tensor<9xf32> - %216 = stablehlo.is_finite %215 : (tensor<9xf32>) -> tensor<9xi1> -@@ -471,17 +471,17 @@ - %228 = stablehlo.add %226, %227 : tensor<9xf32> - %229 = stablehlo.constant dense<7.500000e+00> : tensor - %230 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %231 = stablehlo.add %230, %226 : tensor<9xf32> -+ %231 = stablehlo.add %226, %230 : tensor<9xf32> - %232 = stablehlo.constant dense<2.01490307> : tensor - %233 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %234 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %235 = stablehlo.divide %226, %234 : tensor<9xf32> - %236 = stablehlo.log_plus_one %235 : tensor<9xf32> -- %237 = stablehlo.add %233, %236 : tensor<9xf32> -+ %237 = stablehlo.add %236, %233 : tensor<9xf32> - %238 = stablehlo.divide %231, %237 : tensor<9xf32> - %239 = stablehlo.subtract %228, %238 : tensor<9xf32> - %240 = stablehlo.multiply %239, %237 : tensor<9xf32> -- %241 = stablehlo.add %221, %240 : tensor<9xf32> -+ %241 = stablehlo.add %240, %221 : tensor<9xf32> - %242 = stablehlo.constant dense<1.000000e+00> : tensor - %243 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %244 = stablehlo.constant dense<676.520386> : tensor -@@ -492,7 +492,7 @@ - %249 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %250 = stablehlo.add %248, %249 : tensor<9xf32> - %251 = stablehlo.divide %245, %250 : tensor<9xf32> -- %252 = stablehlo.add 
%243, %251 : tensor<9xf32> -+ %252 = stablehlo.add %251, %243 : tensor<9xf32> - %253 = stablehlo.constant dense<-1259.13916> : tensor - %254 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %255 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -586,7 +586,7 @@ - %343 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %344 = stablehlo.subtract %343, %338 : tensor<9xf32> - %345 = stablehlo.select %341, %344, %338 : tensor<9xi1>, tensor<9xf32> -- %346 = stablehlo.multiply %335, %345 : tensor<9xf32> -+ %346 = stablehlo.multiply %345, %335 : tensor<9xf32> - %347 = stablehlo.sine %346 : tensor<9xf32> - %348 = stablehlo.log %347 : tensor<9xf32> - %349 = stablehlo.is_finite %348 : (tensor<9xf32>) -> tensor<9xi1> -@@ -604,17 +604,17 @@ - %361 = stablehlo.add %359, %360 : tensor<9xf32> - %362 = stablehlo.constant dense<7.500000e+00> : tensor - %363 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %364 = stablehlo.add %363, %359 : tensor<9xf32> -+ %364 = stablehlo.add %359, %363 : tensor<9xf32> - %365 = stablehlo.constant dense<2.01490307> : tensor - %366 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %367 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %368 = stablehlo.divide %359, %367 : tensor<9xf32> - %369 = stablehlo.log_plus_one %368 : tensor<9xf32> -- %370 = stablehlo.add %366, %369 : tensor<9xf32> -+ %370 = stablehlo.add %369, %366 : tensor<9xf32> - %371 = stablehlo.divide %364, %370 : tensor<9xf32> - %372 = stablehlo.subtract %361, %371 : tensor<9xf32> - %373 = stablehlo.multiply %372, %370 : tensor<9xf32> -- %374 = stablehlo.add %354, %373 : tensor<9xf32> -+ %374 = stablehlo.add %373, %354 : tensor<9xf32> - %375 = stablehlo.constant dense<1.000000e+00> : tensor - %376 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %377 = stablehlo.constant dense<676.520386> : tensor -@@ -625,7 +625,7 @@ - %382 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %383 = stablehlo.add %381, %382 : tensor<9xf32> - %384 = stablehlo.divide %378, %383 : tensor<9xf32> -- %385 = stablehlo.add %376, %384 : tensor<9xf32> -+ %385 = stablehlo.add %384, %376 : tensor<9xf32> - %386 = stablehlo.constant dense<-1259.13916> : tensor - %387 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %388 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir b/stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir ---- stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir -+++ stablehlo/stablehlo/testdata/regularized_incomplete_beta__float32.mlir -@@ -71,9 +71,9 @@ - %40 = stablehlo.multiply %38, %39 : tensor<9xf32> - %41 = stablehlo.constant dense<2.000000e+00> : tensor - %42 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %43 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %43 = stablehlo.multiply %32, %42 : tensor<9xf32> - %44 = stablehlo.add %25, %43 : tensor<9xf32> -- %45 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %45 = stablehlo.multiply %32, %42 : tensor<9xf32> - %46 = stablehlo.add %25, %45 : tensor<9xf32> - %47 = stablehlo.constant dense<1.000000e+00> : tensor - %48 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -83,10 +83,10 @@ - %52 = stablehlo.subtract %35, %32 : tensor<9xf32> - %53 = stablehlo.multiply %32, %52 : tensor<9xf32> - %54 = stablehlo.multiply %53, %39 : tensor<9xf32> -- %55 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %55 = stablehlo.multiply %32, %42 : tensor<9xf32> - %56 = 
stablehlo.add %25, %55 : tensor<9xf32> - %57 = stablehlo.subtract %56, %48 : tensor<9xf32> -- %58 = stablehlo.multiply %42, %32 : tensor<9xf32> -+ %58 = stablehlo.multiply %32, %42 : tensor<9xf32> - %59 = stablehlo.add %25, %58 : tensor<9xf32> - %60 = stablehlo.multiply %57, %59 : tensor<9xf32> - %61 = stablehlo.divide %54, %60 : tensor<9xf32> -@@ -222,9 +222,9 @@ - %498 = stablehlo.multiply %497, %iterArg_6 : tensor<9xf32> - %499 = stablehlo.constant dense<2.000000e+00> : tensor - %500 = stablehlo.constant dense<2.000000e+00> : tensor<9xf32> -- %501 = stablehlo.multiply %500, %492 : tensor<9xf32> -+ %501 = stablehlo.multiply %492, %500 : tensor<9xf32> - %502 = stablehlo.add %iterArg_4, %501 : tensor<9xf32> -- %503 = stablehlo.multiply %500, %492 : tensor<9xf32> -+ %503 = stablehlo.multiply %492, %500 : tensor<9xf32> - %504 = stablehlo.add %iterArg_4, %503 : tensor<9xf32> - %505 = stablehlo.constant dense<1.000000e+00> : tensor - %506 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> -@@ -234,10 +234,10 @@ - %510 = stablehlo.subtract %iterArg_5, %492 : tensor<9xf32> - %511 = stablehlo.multiply %492, %510 : tensor<9xf32> - %512 = stablehlo.multiply %511, %iterArg_6 : tensor<9xf32> -- %513 = stablehlo.multiply %500, %492 : tensor<9xf32> -+ %513 = stablehlo.multiply %492, %500 : tensor<9xf32> - %514 = stablehlo.add %iterArg_4, %513 : tensor<9xf32> - %515 = stablehlo.subtract %514, %506 : tensor<9xf32> -- %516 = stablehlo.multiply %500, %492 : tensor<9xf32> -+ %516 = stablehlo.multiply %492, %500 : tensor<9xf32> - %517 = stablehlo.add %iterArg_4, %516 : tensor<9xf32> - %518 = stablehlo.multiply %515, %517 : tensor<9xf32> - %519 = stablehlo.divide %512, %518 : tensor<9xf32> -@@ -319,7 +319,7 @@ - %76 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %77 = stablehlo.subtract %76, %71 : tensor<9xf32> - %78 = stablehlo.select %74, %77, %71 : tensor<9xi1>, tensor<9xf32> -- %79 = stablehlo.multiply %68, %78 : tensor<9xf32> -+ %79 = stablehlo.multiply %78, %68 : tensor<9xf32> - %80 = stablehlo.sine %79 : tensor<9xf32> - %81 = stablehlo.log %80 : tensor<9xf32> - %82 = stablehlo.is_finite %81 : (tensor<9xf32>) -> tensor<9xi1> -@@ -337,17 +337,17 @@ - %94 = stablehlo.add %92, %93 : tensor<9xf32> - %95 = stablehlo.constant dense<7.500000e+00> : tensor - %96 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %97 = stablehlo.add %96, %92 : tensor<9xf32> -+ %97 = stablehlo.add %92, %96 : tensor<9xf32> - %98 = stablehlo.constant dense<2.01490307> : tensor - %99 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %100 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %101 = stablehlo.divide %92, %100 : tensor<9xf32> - %102 = stablehlo.log_plus_one %101 : tensor<9xf32> -- %103 = stablehlo.add %99, %102 : tensor<9xf32> -+ %103 = stablehlo.add %102, %99 : tensor<9xf32> - %104 = stablehlo.divide %97, %103 : tensor<9xf32> - %105 = stablehlo.subtract %94, %104 : tensor<9xf32> - %106 = stablehlo.multiply %105, %103 : tensor<9xf32> -- %107 = stablehlo.add %87, %106 : tensor<9xf32> -+ %107 = stablehlo.add %106, %87 : tensor<9xf32> - %108 = stablehlo.constant dense<1.000000e+00> : tensor - %109 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %110 = stablehlo.constant dense<676.520386> : tensor -@@ -358,7 +358,7 @@ - %115 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %116 = stablehlo.add %114, %115 : tensor<9xf32> - %117 = stablehlo.divide %111, %116 : tensor<9xf32> -- %118 = stablehlo.add %109, %117 : tensor<9xf32> -+ %118 = stablehlo.add %117, %109 : 
tensor<9xf32> - %119 = stablehlo.constant dense<-1259.13916> : tensor - %120 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %121 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -450,7 +450,7 @@ - %207 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %208 = stablehlo.subtract %207, %202 : tensor<9xf32> - %209 = stablehlo.select %205, %208, %202 : tensor<9xi1>, tensor<9xf32> -- %210 = stablehlo.multiply %199, %209 : tensor<9xf32> -+ %210 = stablehlo.multiply %209, %199 : tensor<9xf32> - %211 = stablehlo.sine %210 : tensor<9xf32> - %212 = stablehlo.log %211 : tensor<9xf32> - %213 = stablehlo.is_finite %212 : (tensor<9xf32>) -> tensor<9xi1> -@@ -468,17 +468,17 @@ - %225 = stablehlo.add %223, %224 : tensor<9xf32> - %226 = stablehlo.constant dense<7.500000e+00> : tensor - %227 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %228 = stablehlo.add %227, %223 : tensor<9xf32> -+ %228 = stablehlo.add %223, %227 : tensor<9xf32> - %229 = stablehlo.constant dense<2.01490307> : tensor - %230 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %231 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %232 = stablehlo.divide %223, %231 : tensor<9xf32> - %233 = stablehlo.log_plus_one %232 : tensor<9xf32> -- %234 = stablehlo.add %230, %233 : tensor<9xf32> -+ %234 = stablehlo.add %233, %230 : tensor<9xf32> - %235 = stablehlo.divide %228, %234 : tensor<9xf32> - %236 = stablehlo.subtract %225, %235 : tensor<9xf32> - %237 = stablehlo.multiply %236, %234 : tensor<9xf32> -- %238 = stablehlo.add %218, %237 : tensor<9xf32> -+ %238 = stablehlo.add %237, %218 : tensor<9xf32> - %239 = stablehlo.constant dense<1.000000e+00> : tensor - %240 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %241 = stablehlo.constant dense<676.520386> : tensor -@@ -489,7 +489,7 @@ - %246 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %247 = stablehlo.add %245, %246 : tensor<9xf32> - %248 = stablehlo.divide %242, %247 : tensor<9xf32> -- %249 = stablehlo.add %240, %248 : tensor<9xf32> -+ %249 = stablehlo.add %248, %240 : tensor<9xf32> - %250 = stablehlo.constant dense<-1259.13916> : tensor - %251 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %252 = stablehlo.constant dense<1.000000e+00> : tensor -@@ -583,7 +583,7 @@ - %340 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %341 = stablehlo.subtract %340, %335 : tensor<9xf32> - %342 = stablehlo.select %338, %341, %335 : tensor<9xi1>, tensor<9xf32> -- %343 = stablehlo.multiply %332, %342 : tensor<9xf32> -+ %343 = stablehlo.multiply %342, %332 : tensor<9xf32> - %344 = stablehlo.sine %343 : tensor<9xf32> - %345 = stablehlo.log %344 : tensor<9xf32> - %346 = stablehlo.is_finite %345 : (tensor<9xf32>) -> tensor<9xi1> -@@ -601,17 +601,17 @@ - %358 = stablehlo.add %356, %357 : tensor<9xf32> - %359 = stablehlo.constant dense<7.500000e+00> : tensor - %360 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> -- %361 = stablehlo.add %360, %356 : tensor<9xf32> -+ %361 = stablehlo.add %356, %360 : tensor<9xf32> - %362 = stablehlo.constant dense<2.01490307> : tensor - %363 = stablehlo.constant dense<2.01490307> : tensor<9xf32> - %364 = stablehlo.constant dense<7.500000e+00> : tensor<9xf32> - %365 = stablehlo.divide %356, %364 : tensor<9xf32> - %366 = stablehlo.log_plus_one %365 : tensor<9xf32> -- %367 = stablehlo.add %363, %366 : tensor<9xf32> -+ %367 = stablehlo.add %366, %363 : tensor<9xf32> - %368 = stablehlo.divide %361, %367 : tensor<9xf32> - %369 = stablehlo.subtract %358, %368 : tensor<9xf32> - %370 = 
stablehlo.multiply %369, %367 : tensor<9xf32> -- %371 = stablehlo.add %351, %370 : tensor<9xf32> -+ %371 = stablehlo.add %370, %351 : tensor<9xf32> - %372 = stablehlo.constant dense<1.000000e+00> : tensor - %373 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %374 = stablehlo.constant dense<676.520386> : tensor -@@ -622,7 +622,7 @@ - %379 = stablehlo.constant dense<1.000000e+00> : tensor<9xf32> - %380 = stablehlo.add %378, %379 : tensor<9xf32> - %381 = stablehlo.divide %375, %380 : tensor<9xf32> -- %382 = stablehlo.add %373, %381 : tensor<9xf32> -+ %382 = stablehlo.add %381, %373 : tensor<9xf32> - %383 = stablehlo.constant dense<-1259.13916> : tensor - %384 = stablehlo.constant dense<-1259.13916> : tensor<9xf32> - %385 = stablehlo.constant dense<1.000000e+00> : tensor -diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir ---- stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir -+++ stablehlo/stablehlo/testdata/sinh_shape_bfloat16_20_20.mlir -@@ -19,7 +19,7 @@ - %13 = stablehlo.add %10, %11 : tensor<20x20xf32> - %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> - %15 = stablehlo.add %10, %14 : tensor<20x20xf32> -- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> -+ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> - %17 = stablehlo.abs %2 : tensor<20x20xf32> - %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir ---- stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir -+++ stablehlo/stablehlo/testdata/sinh_shape_float16_20_20.mlir -@@ -19,7 +19,7 @@ - %13 = stablehlo.add %10, %11 : tensor<20x20xf32> - %14 = stablehlo.divide %10, %13 : tensor<20x20xf32> - %15 = stablehlo.add %10, %14 : tensor<20x20xf32> -- %16 = stablehlo.multiply %12, %15 : tensor<20x20xf32> -+ %16 = stablehlo.multiply %15, %12 : tensor<20x20xf32> - %17 = stablehlo.abs %2 : tensor<20x20xf32> - %18 = stablehlo.compare LT, %17, %11 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %19 = stablehlo.select %18, %16, %9 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir b/stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir ---- stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir -+++ stablehlo/stablehlo/testdata/sinh_shape_float32_20_20.mlir -@@ -18,7 +18,7 @@ - %12 = stablehlo.add %9, %10 : tensor<20x20xf32> - %13 = stablehlo.divide %9, %12 : tensor<20x20xf32> - %14 = stablehlo.add %9, %13 : tensor<20x20xf32> -- %15 = stablehlo.multiply %11, %14 : tensor<20x20xf32> -+ %15 = stablehlo.multiply %14, %11 : tensor<20x20xf32> - %16 = stablehlo.abs %0 : tensor<20x20xf32> - %17 = stablehlo.compare LT, %16, %10 : (tensor<20x20xf32>, tensor<20x20xf32>) -> tensor<20x20xi1> - %18 = stablehlo.select %17, %15, %8 : tensor<20x20xi1>, tensor<20x20xf32> -diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir ---- stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir -+++ stablehlo/stablehlo/testdata/slice_in_dim_limit_neg_dynamic.mlir -@@ -3,7 +3,7 @@ - module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor { - %0 = stablehlo.constant 
dense<-1> : tensor -- %1 = stablehlo.add %0, %arg0 : tensor -+ %1 = stablehlo.add %arg0, %0 : tensor - %2 = stablehlo.constant dense<0> : tensor<1xi32> - %3 = stablehlo.constant dense<0> : tensor<1xi32> - %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> -diff --ruN a/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir b/stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir ---- stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir -+++ stablehlo/stablehlo/testdata/slice_in_dim_start_neg_dynamic.mlir -@@ -3,7 +3,7 @@ - module @jit_fun_flat_jax { - func.func public @main(%arg0: tensor, %arg1: tensor {mhlo.sharding = ""}) -> tensor<1x4xf32> { - %0 = stablehlo.constant dense<-1> : tensor -- %1 = stablehlo.add %0, %arg0 : tensor -+ %1 = stablehlo.add %arg0, %0 : tensor - %2 = stablehlo.convert %1 : (tensor) -> tensor - %3 = stablehlo.reshape %2 : (tensor) -> tensor<1xi32> - %4 = stablehlo.constant dense<0> : tensor<1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir ---- stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir -+++ stablehlo/stablehlo/testdata/take__enable_xla_True_dynamic.mlir -@@ -29,7 +29,7 @@ - %20 = stablehlo.compare LT, %8, %19, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> - %21 = stablehlo.constant dense<3> : tensor - %22 = stablehlo.broadcast_in_dim %21, dims = [] : (tensor) -> tensor<1xi64> -- %23 = stablehlo.add %8, %22 : tensor<1xi64> -+ %23 = stablehlo.add %22, %8 : tensor<1xi64> - %24 = stablehlo.select %20, %23, %8 : tensor<1xi1>, tensor<1xi64> - %25 = stablehlo.convert %24 : (tensor<1xi64>) -> tensor<1xi32> - %26 = stablehlo.broadcast_in_dim %25, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> -@@ -46,7 +46,7 @@ - %37 = stablehlo.compare LT, %9, %36, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> - %38 = stablehlo.constant dense<3> : tensor - %39 = stablehlo.broadcast_in_dim %38, dims = [] : (tensor) -> tensor<1xi64> -- %40 = stablehlo.add %9, %39 : tensor<1xi64> -+ %40 = stablehlo.add %39, %9 : tensor<1xi64> - %41 = stablehlo.select %37, %40, %9 : tensor<1xi1>, tensor<1xi64> - %42 = stablehlo.convert %41 : (tensor<1xi64>) -> tensor<1xi32> - %43 = stablehlo.broadcast_in_dim %42, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir ---- stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir -+++ stablehlo/stablehlo/testdata/take_along_axis_0_dynamic.mlir -@@ -34,7 +34,7 @@ - %25 = stablehlo.compare LT, %16, %24, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> - %26 = stablehlo.constant dense<2> : tensor - %27 = stablehlo.broadcast_in_dim %26, dims = [] : (tensor) -> tensor<1xi64> -- %28 = stablehlo.add %16, %27 : tensor<1xi64> -+ %28 = stablehlo.add %27, %16 : tensor<1xi64> - %29 = stablehlo.select %25, %28, %16 : tensor<1xi1>, tensor<1xi64> - %30 = stablehlo.convert %29 : (tensor<1xi64>) -> tensor<1xi32> - %31 = stablehlo.broadcast_in_dim %30, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> -@@ -49,7 +49,7 @@ - %40 = stablehlo.compare LT, %17, %39, SIGNED : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1> - %41 = stablehlo.constant dense<2> : tensor - %42 = stablehlo.broadcast_in_dim %41, dims = [] : (tensor) -> tensor<1xi64> -- %43 = stablehlo.add %17, %42 : tensor<1xi64> -+ %43 = stablehlo.add %42, %17 : tensor<1xi64> - 
%44 = stablehlo.select %40, %43, %17 : tensor<1xi1>, tensor<1xi64> - %45 = stablehlo.convert %44 : (tensor<1xi64>) -> tensor<1xi32> - %46 = stablehlo.broadcast_in_dim %45, dims = [0] : (tensor<1xi32>) -> tensor<1x1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir b/stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir ---- stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir -+++ stablehlo/stablehlo/testdata/take_along_axis_1_dynamic.mlir -@@ -47,7 +47,7 @@ - %38 = stablehlo.compare LT, %29, %37, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %39 = stablehlo.constant dense<2> : tensor - %40 = stablehlo.broadcast_in_dim %39, dims = [] : (tensor) -> tensor<2xi64> -- %41 = stablehlo.add %29, %40 : tensor<2xi64> -+ %41 = stablehlo.add %40, %29 : tensor<2xi64> - %42 = stablehlo.select %38, %41, %29 : tensor<2xi1>, tensor<2xi64> - %43 = stablehlo.convert %42 : (tensor<2xi64>) -> tensor<2xi32> - %44 = stablehlo.broadcast_in_dim %43, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -@@ -62,7 +62,7 @@ - %53 = stablehlo.compare LT, %30, %52, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %54 = stablehlo.constant dense<2> : tensor - %55 = stablehlo.broadcast_in_dim %54, dims = [] : (tensor) -> tensor<2xi64> -- %56 = stablehlo.add %30, %55 : tensor<2xi64> -+ %56 = stablehlo.add %55, %30 : tensor<2xi64> - %57 = stablehlo.select %53, %56, %30 : tensor<2xi1>, tensor<2xi64> - %58 = stablehlo.convert %57 : (tensor<2xi64>) -> tensor<2xi32> - %59 = stablehlo.broadcast_in_dim %58, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir ---- stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir -+++ stablehlo/stablehlo/testdata/vmap_gather_dtypes_shape_float32_10__axis_0_enable_xla_True_dynamic.mlir -@@ -41,7 +41,7 @@ - %32 = stablehlo.compare LT, %22, %31, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %33 = stablehlo.constant dense<2> : tensor - %34 = stablehlo.broadcast_in_dim %33, dims = [] : (tensor) -> tensor<2xi64> -- %35 = stablehlo.add %22, %34 : tensor<2xi64> -+ %35 = stablehlo.add %34, %22 : tensor<2xi64> - %36 = stablehlo.select %32, %35, %22 : tensor<2xi1>, tensor<2xi64> - %37 = stablehlo.convert %36 : (tensor<2xi64>) -> tensor<2xi32> - %38 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -@@ -56,7 +56,7 @@ - %47 = stablehlo.compare LT, %23, %46, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %48 = stablehlo.constant dense<2> : tensor - %49 = stablehlo.broadcast_in_dim %48, dims = [] : (tensor) -> tensor<2xi64> -- %50 = stablehlo.add %23, %49 : tensor<2xi64> -+ %50 = stablehlo.add %49, %23 : tensor<2xi64> - %51 = stablehlo.select %47, %50, %23 : tensor<2xi1>, tensor<2xi64> - %52 = stablehlo.convert %51 : (tensor<2xi64>) -> tensor<2xi32> - %53 = stablehlo.broadcast_in_dim %52, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir ---- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir -+++ 
stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_0_enable_xla_True_mode_fill_dynamic.mlir -@@ -45,7 +45,7 @@ - %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %37 = stablehlo.constant dense<4> : tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> -- %39 = stablehlo.add %22, %38 : tensor<2xi64> -+ %39 = stablehlo.add %38, %22 : tensor<2xi64> - %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> - %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> - %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -@@ -64,7 +64,7 @@ - %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %56 = stablehlo.constant dense<4> : tensor - %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> -- %58 = stablehlo.add %23, %57 : tensor<2xi64> -+ %58 = stablehlo.add %57, %23 : tensor<2xi64> - %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> - %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> - %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -diff --ruN a/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir b/stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir ---- stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir -+++ stablehlo/stablehlo/testdata/vmap_gather_from_take_indices_name__1__axis_2_enable_xla_True_mode_fill_dynamic.mlir -@@ -45,7 +45,7 @@ - %36 = stablehlo.compare LT, %22, %35, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %37 = stablehlo.constant dense<4> : tensor - %38 = stablehlo.broadcast_in_dim %37, dims = [] : (tensor) -> tensor<2xi64> -- %39 = stablehlo.add %22, %38 : tensor<2xi64> -+ %39 = stablehlo.add %38, %22 : tensor<2xi64> - %40 = stablehlo.select %36, %39, %22 : tensor<2xi1>, tensor<2xi64> - %41 = stablehlo.convert %40 : (tensor<2xi64>) -> tensor<2xi32> - %42 = stablehlo.broadcast_in_dim %41, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -@@ -64,7 +64,7 @@ - %55 = stablehlo.compare LT, %23, %54, SIGNED : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi1> - %56 = stablehlo.constant dense<4> : tensor - %57 = stablehlo.broadcast_in_dim %56, dims = [] : (tensor) -> tensor<2xi64> -- %58 = stablehlo.add %23, %57 : tensor<2xi64> -+ %58 = stablehlo.add %57, %23 : tensor<2xi64> - %59 = stablehlo.select %55, %58, %23 : tensor<2xi1>, tensor<2xi64> - %60 = stablehlo.convert %59 : (tensor<2xi64>) -> tensor<2xi32> - %61 = stablehlo.broadcast_in_dim %60, dims = [0] : (tensor<2xi32>) -> tensor<2x1xi32> -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir ---- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -+++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir -@@ -426,6 +426,172 @@ - - // ----- - -+// CHECK-LABEL: func @dynamic_reduce_window_success_static_result_type -+func.func @dynamic_reduce_window_success_static_result_type(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<2x2xf32> { -+ // CHECK-NOT: stablehlo.dynamic_reduce_window -+ // CHECK: "stablehlo.reduce_window"(%arg0, %arg1) ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): -+ // CHECK-NEXT: %[[VAL1:.*]] = 
stablehlo.add %arg2, %arg3 : tensor -+ // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor -+ // CHECK-NEXT: }) { -+ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, -+ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> -+ // CHECK-SAME: } : (tensor<3x2xf32>, tensor) -> tensor<2x2xf32> -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %5 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_success_dynamic_result_type -+func.func @dynamic_reduce_window_success_dynamic_result_type(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK-NOT: stablehlo.dynamic_reduce_window -+ // CHECK: "stablehlo.reduce_window"(%arg0, %arg1) ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: tensor, %[[ARG3:arg.*]]: tensor): -+ // CHECK-NEXT: %[[VAL1:.*]] = stablehlo.add %arg2, %arg3 : tensor -+ // CHECK-NEXT: stablehlo.return %[[VAL1]] : tensor -+ // CHECK-NEXT: }) { -+ // CHECK-SAME: base_dilations = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME{LITERAL}: padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, -+ // CHECK-SAME: window_dilations = dense<[3, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_dimensions = dense<[2, 1]> : tensor<2xi64>, -+ // CHECK-SAME: window_strides = dense<[4, 1]> : tensor<2xi64> -+ // CHECK-SAME: } : (tensor, tensor) -> tensor -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor -+ func.return %5 : tensor -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// TODO(burmako): Implement tests for verification failures for dynamic_reduce_window. 
-+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_dimensions -+func.func @dynamic_reduce_window_inapplicable_dynamic_window_dimensions(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window -+ %0 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %arg2, %0, %1, %2, %3) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %4 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_strides -+func.func @dynamic_reduce_window_inapplicable_dynamic_window_strides(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %arg2, %1, %2, %3) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %4 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_base_dilations -+func.func @dynamic_reduce_window_inapplicable_dynamic_base_dilations(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %arg2, %2, %3) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %4 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_window_dilations -+func.func @dynamic_reduce_window_inapplicable_dynamic_window_dilations(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2xi64>) -> tensor<2x2xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %3 = 
stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %arg2, %3) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %4 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_reduce_window_inapplicable_dynamic_padding -+func.func @dynamic_reduce_window_inapplicable_dynamic_padding(%arg0: tensor<3x2xf32>, %arg1: tensor, %arg2: tensor<2x2xi64>) -> tensor<2x2xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %4 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %arg2) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<2x2xf32> -+ func.return %4 : tensor<2x2xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ - // CHECK-LABEL: func @dynamic_reshape_success - func.func @dynamic_reshape_success(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NOT: stablehlo.dynamic_reshape -@@ -452,6 +618,44 @@ - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x?xf32> - return %1 : tensor<1x?xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_rng_bit_generator_success -+func.func @dynamic_rng_bit_generator_success(%arg0: tensor<2xui64>) -> tensor<1x4xf32> { -+ // CHECK-NOT: stablehlo.dynamic_rng_bit_generator -+ // CHECK: stablehlo.rng_bit_generator %arg0, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<1x4xf32>) -+ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> -+ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { -+ rng_algorithm = #stablehlo -+ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor<1x4xf32>) -+ return %1#1 : tensor<1x4xf32> -+} -+ -+// TODO(burmako): Implement tests for verification failures for dynamic_rng_bit_generator. 
-+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_rng_bit_generator_inapplicable_dynamic_output_shape -+func.func @dynamic_rng_bit_generator_inapplicable_dynamic_output_shape(%arg0: tensor<2xui64>, %arg1: tensor<2xi64>) -> tensor<1x4xf32> { -+ // CHECK: stablehlo.dynamic_rng_bit_generator -+ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %arg1) { -+ rng_algorithm = #stablehlo -+ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor<1x4xf32>) -+ return %1#1 : tensor<1x4xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @dynamic_rng_bit_generator_inapplicable_dynamic_output_type -+func.func @dynamic_rng_bit_generator_inapplicable_dynamic_output_type(%arg0: tensor<2xui64>) -> tensor { -+ // CHECK: stablehlo.dynamic_rng_bit_generator -+ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> -+ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { -+ rng_algorithm = #stablehlo -+ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor) -+ return %1#1 : tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir ---- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -+++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -607,12 +607,45 @@ - - // ----- - -+// CHECK-LABEL: @main -+func.func @main(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<*xf32> { -+ // CHECK: stablehlo.dynamic_reduce_window{{.*}} -> tensor<2x2xf32> -+ %0 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %1 = stablehlo.constant dense<[4, 1]> : tensor<2xi64> -+ %2 = stablehlo.constant dense<[2, 1]> : tensor<2xi64> -+ %3 = stablehlo.constant dense<[3, 1]> : tensor<2xi64> -+ %4 = stablehlo.constant dense<[[2, 1], [0, 0]]> : tensor<2x2xi64> -+ %5 = stablehlo.custom_call @stablehlo.dynamic_reduce_window(%arg0, %arg1, %0, %1, %2, %3, %4) { -+ called_computations = [@dynamic_reduce_window0] -+ } : (tensor<3x2xf32>, tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> -+ func.return %5 : tensor<*xf32> -+} -+ -+func.func private @dynamic_reduce_window0(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = stablehlo.add %arg0, %arg1 : tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ - // CHECK-LABEL: @refine_dynamic_reshape - func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> { - // CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32> - %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @refine_dynamic_rng_bit_generator -+func.func @refine_dynamic_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor, tensor<*xf32>) { -+ // CHECK: stablehlo.dynamic_rng_bit_generator{{.*}} -> (tensor<2xui64>, tensor<1x4xf32>) -+ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> -+ %1:2 = stablehlo.custom_call @stablehlo.dynamic_rng_bit_generator(%arg0, %0) { -+ rng_algorithm = #stablehlo -+ } : (tensor<2xui64>, tensor<2xi64>) -> (tensor, tensor<*xf32>) -+ func.return %1#0, %1#1 : tensor, tensor<*xf32> - } - - // ----- -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp ---- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -24,6 +24,7 @@ - #include 
"mlir/Interfaces/InferTypeOpInterface.h" - #include "mlir/Support/LogicalResult.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/transforms/Passes.h" - -@@ -198,6 +199,54 @@ - } - }; - -+struct CanonicalizeDynamicReduceWindowOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicReduceWindowOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicReduceWindowOpAdaptor op = *maybeOp; -+ -+ // ReduceWindowOp supports dynamic shapes for operands and results, so we -+ // don't check for that here unlike in some other patterns in this pass. -+ SmallVector windowDimensions, windowStrides, baseDilations, -+ windowDilations, padding; -+ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) -+ return rewriter.notifyMatchFailure(op, -+ "expected static window_dimensions"); -+ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) -+ return rewriter.notifyMatchFailure(op, "expected static window_strides"); -+ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) -+ return rewriter.notifyMatchFailure(op, "expected static base_dilations"); -+ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected static window_dilations"); -+ if (failed(hlo::matchInts(op.getPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected static padding"); -+ auto newOp = rewriter.create( -+ op->getLoc(), op->getResultTypes(), op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), -+ hlo::getPaddingAttr(&rewriter, padding)); -+ -+ // Inline the called computation into newOp. -+ // This is somewhat annoying because we also have to rewrite the original -+ // func::ReturnOp into stablehlo::ReturnOp. -+ rewriter.cloneRegionBefore(op.getBody(), newOp.getBody(), -+ newOp.getBody().end()); -+ auto funcReturnOp = -+ cast(newOp.getBody().front().getTerminator()); -+ rewriter.setInsertionPointToEnd(&newOp.getBody().front()); -+ rewriter.replaceOpWithNewOp( -+ funcReturnOp, funcReturnOp.getOperands()); -+ rewriter.replaceOp(op, newOp->getResults()); -+ return success(); -+ } -+}; -+ - struct CanonicalizeDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; -@@ -210,6 +259,27 @@ - if (!op.getType().hasStaticShape()) - return rewriter.notifyMatchFailure(op, "expected static result type"); - rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); -+ return success(); -+ } -+}; -+ -+struct CanonicalizeDynamicRngBitGeneratorOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicRngBitGeneratorOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; -+ -+ // This pattern ignores and discards the output_shape operand. We rely on -+ // the verifier to make sure that its value is consistent with result type. 
-+ if (!succeeded(hlo::matchInts(op.getOutputShape()))) -+ return rewriter.notifyMatchFailure(op, "expected static output_shape"); -+ if (!op.getOutput().getType().cast().hasStaticShape()) -+ return rewriter.notifyMatchFailure(op, "expected static output type"); -+ rewriter.replaceOpWithNewOp( -+ op, op->getResultTypes(), op.getRngAlgorithm(), op.getInitialState()); - return success(); - } - }; -@@ -320,7 +390,9 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add( - &getContext()); - patterns.add(&getContext()); -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -43,6 +43,7 @@ - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "stablehlo/dialect/Base.h" - #include "stablehlo/dialect/ChloOps.h" -+#include "stablehlo/dialect/ExperimentalOps.h" - #include "stablehlo/dialect/StablehloOps.h" - #include "stablehlo/dialect/TypeInference.h" - #include "stablehlo/transforms/Passes.h" -@@ -844,12 +845,78 @@ - } - }; - -+struct RefineDynamicReduceWindowOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicReduceWindowOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicReduceWindowOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. 
-+ SmallVector windowDimensions, windowStrides, baseDilations, -+ windowDilations, padding; -+ if (failed(hlo::matchInts(op.getWindowDimensions(), windowDimensions))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dimensions"); -+ if (failed(hlo::matchInts(op.getWindowStrides(), windowStrides))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_strides"); -+ if (failed(hlo::matchInts(op.getBaseDilations(), baseDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant base_dilations"); -+ if (failed(hlo::matchInts(op.getWindowDilations(), windowDilations))) -+ return rewriter.notifyMatchFailure(op, -+ "expected constant window_dilations"); -+ if (failed(hlo::matchInts(op.getPadding(), padding))) -+ return rewriter.notifyMatchFailure(op, "expected constant padding"); -+ -+ SmallVector inferredReturnTypes; -+ if (failed(hlo::inferReduceWindowOp( -+ /*location=*/{}, op.getInputs(), op.getInitValues(), -+ rewriter.getI64TensorAttr(windowDimensions), -+ rewriter.getI64TensorAttr(windowStrides), -+ rewriter.getI64TensorAttr(baseDilations), -+ rewriter.getI64TensorAttr(windowDilations), -+ hlo::getPaddingAttr(&rewriter, padding), inferredReturnTypes))) -+ return rewriter.notifyMatchFailure(op, "inferReduceWindowOp failed"); -+ return refineReturnTypes(rewriter, op, inferredReturnTypes); -+ } -+}; -+ - struct RefineDynamicReshapeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DynamicReshapeOp op, - PatternRewriter& rewriter) const override { - return refineReturnShape(rewriter, op, op.getOutputShape()); -+ } -+}; -+ -+struct RefineDynamicRngBitGeneratorOpPattern -+ : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ LogicalResult matchAndRewrite(CustomCallOp impl, -+ PatternRewriter& rewriter) const override { -+ auto maybeOp = getDynamicRngBitGeneratorOp(impl); -+ if (!maybeOp || failed(maybeOp->verify())) return failure(); -+ DynamicRngBitGeneratorOpAdaptor op = *maybeOp; -+ -+ // At the moment, we only support refining return types using fully static -+ // shape values which serves the current use cases well. -+ // Support for partially static shape values is left for future work. -+ auto initialStateType = op.getInitialState().getType().cast(); -+ SmallVector outputShape; -+ if (failed(hlo::matchInts(op.getOutputShape(), outputShape))) -+ return rewriter.notifyMatchFailure(op, "expected constant output_shape"); -+ -+ // We only need to refine the shape of `output` (the second result). -+ // The shape of `output_state` (the first result) is determined by the shape -+ // of `initial_state`, so we ignore it and provide an empty refinement. -+ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); - } - }; - -@@ -1181,7 +1248,9 @@ - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); -+ patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - From 00101baa56c6b78bad7ac3fd1064b18e65f4b76d Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Thu, 27 Jul 2023 22:49:30 -0700 Subject: [PATCH 294/410] Internal Code Change PiperOrigin-RevId: 551743088 --- tensorflow/tsl/platform/ram_file_system.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/tsl/platform/ram_file_system.h b/tensorflow/tsl/platform/ram_file_system.h index 65aa9873ee50d0..8faa450cae9bd5 100644 --- a/tensorflow/tsl/platform/ram_file_system.h +++ b/tensorflow/tsl/platform/ram_file_system.h @@ -28,6 +28,7 @@ limitations under the License. #include +#include "absl/strings/match.h" #include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/file_system.h" #include "tensorflow/tsl/platform/mutex.h" @@ -332,11 +333,11 @@ class RamFileSystem : public FileSystem { } bool StartsWith(std::string s, std::string prefix) { - return s.find(prefix) == 0; + return absl::StartsWith(s, prefix); } string StripPrefix(std::string s, std::string prefix) { - if (s.find(prefix) == 0) { + if (absl::StartsWith(s, prefix)) { return s.erase(0, prefix.size()); } return s; From 7a1c8e087bfdba5debf6698f09e6e04e7169294b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 27 Jul 2023 23:44:50 -0700 Subject: [PATCH 295/410] [xla:gpu] Add lmhlo_gpu.gemm (cuBLAS) support Implement XLA<->IREE<->StreamExecutor integration as a custom VM module (XlaGpuModule). Coming next: - Tracing/XProf integration - GEMM config caching/global initializer required for performance PiperOrigin-RevId: 551752820 --- .../compiler/xla/mlir/backends/openxla/BUILD | 1 + .../mlir/backends/openxla/conversion/BUILD | 30 +- .../openxla/conversion/convert_library_ops.cc | 292 ++++++++++++++++++ .../openxla/conversion/convert_library_ops.h | 37 +++ .../xla/mlir/backends/openxla/ir/BUILD | 48 +++ .../backends/openxla/ir/xla_gpu_dialect.cc | 39 +++ .../backends/openxla/ir/xla_gpu_dialect.h | 28 ++ .../backends/openxla/ir/xla_gpu_dialect.td | 69 +++++ .../mlir/backends/openxla/transforms/BUILD | 2 + .../openxla/transforms/convert_to_openxla.cc | 18 ++ .../backends/openxla/transforms/passes.td | 1 + .../transforms/tests/fusion_to_openxla.mlir | 3 + .../transforms/tests/gemm_to_openxla.mlir | 44 +++ .../transforms/tests/memref_to_openxla.mlir | 4 + .../transforms/tests/sort_to_openxla.mlir | 1 + .../tests/while_loop_to_openxla.mlir | 1 + .../mlir/backends/openxla/xla-openxla-opt.cc | 6 + .../compiler/xla/service/gpu/openxla/BUILD | 152 +++++---- .../xla/service/gpu/openxla/build_config.bzl | 13 + .../xla/service/gpu/openxla/compiler.cc | 2 +- .../xla/service/gpu/openxla/executable.cc | 17 +- .../xla/service/gpu/openxla/executable.h | 2 +- .../compiler/xla/service/gpu/openxla/gemm.cc | 148 +++++++++ .../compiler/xla/service/gpu/openxla/gemm.h | 113 +++++++ .../compiler/xla/service/gpu/openxla/hal.cc | 77 +++++ .../compiler/xla/service/gpu/openxla/hal.h | 38 +++ .../xla/service/gpu/openxla/module.cc | 161 ++++++++++ .../compiler/xla/service/gpu/openxla/module.h | 35 +++ .../compiler/xla/service/gpu/openxla/vm.cc | 46 +++ .../compiler/xla/service/gpu/openxla/vm.h | 62 ++++ 30 files changed, 1427 insertions(+), 63 deletions(-) create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/ir/BUILD create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.cc create mode 100644 
tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/gemm.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/gemm.h create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/hal.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/hal.h create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/module.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/module.h create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/vm.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/vm.h diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD index e17129b1968409..886086e8efcb73 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD @@ -30,6 +30,7 @@ config_setting( # "@llvm-project//mlir:FuncDialect", # "@llvm-project//mlir:MemRefDialect", # "@llvm-project//mlir:MlirOptLib", +# "@llvm-project//mlir:Transforms", # "//tensorflow/compiler/xla/mlir/backends/openxla/transforms:passes", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", # "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD index adc00ecc48d2d1..589f85c39b3750 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD @@ -11,7 +11,7 @@ package( # # cc_library( # name = "de_bufferization", -# hdrs = if_openxla(["de_bufferization.h"]), +# hdrs = ["de_bufferization.h"], # deps = [ # "@llvm-project//llvm:Support", # "@llvm-project//mlir:IR", @@ -24,7 +24,7 @@ package( # srcs = if_openxla(["convert_compiled_ops.cc"]), # hdrs = if_openxla(["convert_compiled_ops.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] -# # because IREE targets are not compatible with `non_prod` constraint. +# # because IREE targets are not compatible with the `non_prod` constraint. # compatible_with = [], # deps = [ # ":de_bufferization", @@ -45,11 +45,33 @@ package( # ) # # cc_library( +# name = "convert_library_ops", +# srcs = if_openxla(["convert_library_ops.cc"]), +# hdrs = if_openxla(["convert_library_ops.h"]), +# # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] +# # because IREE targets are not compatible with the `non_prod` constraint. 
+# compatible_with = [], +# deps = [ +# ":de_bufferization", +# "@llvm-project//llvm:Support", +# "@llvm-project//mlir:ArithDialect", +# "@llvm-project//mlir:FuncDialect", +# "@llvm-project//mlir:IR", +# "@llvm-project//mlir:MemRefDialect", +# "@llvm-project//mlir:Support", +# "@llvm-project//mlir:TensorDialect", +# "@llvm-project//mlir:Transforms", +# "//tensorflow/compiler/xla/mlir/backends/openxla/ir:xla_gpu", +# "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", +# ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), +# ) +# +# cc_library( # name = "convert_memref_ops", # srcs = if_openxla(["convert_memref_ops.cc"]), # hdrs = if_openxla(["convert_memref_ops.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] -# # because IREE targets are not compatible with `non_prod` constraint. +# # because IREE targets are not compatible with the `non_prod` constraint. # compatible_with = [], # deps = [ # ":de_bufferization", @@ -67,7 +89,7 @@ package( # srcs = if_openxla(["convert_while_op.cc"]), # hdrs = if_openxla(["convert_while_op.h"]), # # TODO(ezhulenev): Override cc_library()'s internal default value of ["//buildenv/target:gce"] -# # because IREE targets are not compatible with `non_prod` constraint. +# # because IREE targets are not compatible with the `non_prod` constraint. # compatible_with = [], # deps = [ # ":de_bufferization", diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc new file mode 100644 index 00000000000000..9118f6c9d1039b --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc @@ -0,0 +1,292 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h" + +#include +#include +#include +#include + +#include "third_party/iree/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h" +#include "third_party/iree/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" + +namespace xla::gpu { + +namespace { +using namespace mlir; // NOLINT +using namespace mlir::iree_compiler; // NOLINT + +using arith::ConstantIndexOp; +using arith::ConstantIntOp; +using arith::ConstantOp; + +//===----------------------------------------------------------------------===// +// Helper class to set up OpenXLA runtime API declarations +//===----------------------------------------------------------------------===// + +// XLA GPU <-> StreamExecutor integration as API declarations. +class XlaGpuApi { + public: + // Imports `@xla_gpu.dot_dimension_numbers.create` into the module. + func::FuncOp getCreateDotDimensionsNumbers(OpBuilder &b, ModuleOp module); + + // Imports `@xla_gpu.dot_precision.create` into the module. + func::FuncOp getCreateDotPrecision(OpBuilder &b, ModuleOp module); + + // Imports `@xla_gpu.dot_config.create` into the module. + func::FuncOp getCreateDotConfig(OpBuilder &b, ModuleOp module); + + // Imports `@xla_gpu.gemm.dispatch` into the module. 
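+  // The declaration assembled in the definition below takes the execution
+  // context, three buffer views (lhs, rhs, out) and a dot config value, and
+  // returns nothing; ConvertGemmOp at the end of this file is expected to
+  // pass its operands in that order.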
+ func::FuncOp getDispatchGemm(OpBuilder &b, ModuleOp module); + + private: + SymbolTable &symTable(ModuleOp module); + + func::FuncOp addDecl(OpBuilder &b, ModuleOp module, std::string_view name, + FunctionType function_type); + + SymbolTableCollection sym_table_; +}; + +Type getI64ListType(MLIRContext *ctx) { + return IREE::Input::ListType::get(ctx, IntegerType::get(ctx, 64)); +} + +func::FuncOp XlaGpuApi::getCreateDotDimensionsNumbers(OpBuilder &b, + ModuleOp module) { + auto i64_list = getI64ListType(b.getContext()); + SmallVector args = {/*lhs_batch_dimensions=*/i64_list, + /*rhs_batch_dimensions=*/i64_list, + /*lhs_contracting_dimensions=*/i64_list, + /*rhs_contracting_dimensions=*/i64_list}; + SmallVector rets = {b.getType()}; + return addDecl(b, module, "xla_gpu.dot_dimension_numbers.create", + FunctionType::get(b.getContext(), args, rets)); +} + +func::FuncOp XlaGpuApi::getCreateDotPrecision(OpBuilder &b, ModuleOp module) { + SmallVector args = {getI64ListType(b.getContext())}; + SmallVector rets = {b.getType()}; + return addDecl(b, module, "xla_gpu.dot_precision.create", + FunctionType::get(b.getContext(), args, rets)); +} + +func::FuncOp XlaGpuApi::getCreateDotConfig(OpBuilder &b, ModuleOp module) { + SmallVector args = {b.getI32Type(), // algorithm + b.getF64Type(), // alpha_real + b.getF64Type(), // alpha_imag + b.getF64Type(), // beta + b.getType(), + b.getType()}; + SmallVector rets = {b.getType()}; + return addDecl(b, module, "xla_gpu.dot_config.create", + FunctionType::get(b.getContext(), args, rets)); +} + +func::FuncOp XlaGpuApi::getDispatchGemm(OpBuilder &b, ModuleOp module) { + auto execution_context = b.getType(); + auto buffer_view = b.getType(); + SmallVector args = {execution_context, buffer_view, buffer_view, + buffer_view, b.getType()}; + return addDecl(b, module, "xla_gpu.gemm.dispatch", + FunctionType::get(b.getContext(), args, /*rets=*/TypeRange())); +} + +SymbolTable &XlaGpuApi::symTable(ModuleOp module) { + return sym_table_.getSymbolTable(module); +} + +func::FuncOp XlaGpuApi::addDecl(OpBuilder &b, ModuleOp module, + std::string_view name, + FunctionType function_type) { + if (auto fn = sym_table_.lookupNearestSymbolFrom( + module, b.getStringAttr(name))) + return fn; + + Location loc = UnknownLoc::get(module->getContext()); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(module.getBody()); + + auto fn = b.create(loc, name, function_type); + fn.setPrivate(); + symTable(module).insert(fn); + return fn; +} + +//===----------------------------------------------------------------------===// +// Helper functions to build arguments to API functions. +//===----------------------------------------------------------------------===// + +// Creates `iree_input.list` list from the given values. 
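+//
+// Written out with explicit template arguments, the builder calls below are
+// expected to take roughly this form (sketch; assumes the `iree_input` list
+// ops ListCreateOp/ListResizeOp/ListSetOp and the arith constant ops imported
+// above):
+//
+//   Value size = b.create<ConstantIndexOp>(values.size());
+//   Value list = b.create<IREE::Input::ListCreateOp>(getI64ListType(ctx), size);
+//   b.create<IREE::Input::ListResizeOp>(list, size);
+//   Value elem = b.create<ConstantIntOp>(value, /*width=*/64);
+//   b.create<IREE::Input::ListSetOp>(list, index, elem);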
+TypedValue getI64List(ImplicitLocOpBuilder &b, + ArrayRef values) { + MLIRContext *ctx = b.getContext(); + + Value size = b.create(values.size()); + Value list = b.create(getI64ListType(ctx), size); + + if (!values.empty()) b.create(list, size); + for (auto indexed : llvm::enumerate(values)) { + Value index = b.create(indexed.index()); + Value value = b.create(indexed.value(), 64); + b.create(list, index, value); + } + + return list.cast>(); +} + +//===----------------------------------------------------------------------===// +// Converts lmhlo_gpu.gemm op to OpenXLA runtime calls +//===----------------------------------------------------------------------===// + +TypedValue getDotDimensionNumbers( + XlaGpuApi &api, ImplicitLocOpBuilder &b, ModuleOp module, + lmhlo_gpu::GEMMOp op) { + mhlo::DotDimensionNumbersAttr attr = op.getDotDimensionNumbersAttr(); + SmallVector args = {getI64List(b, attr.getLhsBatchingDimensions()), + getI64List(b, attr.getRhsBatchingDimensions()), + getI64List(b, attr.getLhsContractingDimensions()), + getI64List(b, attr.getRhsContractingDimensions())}; + + auto api_func = api.getCreateDotDimensionsNumbers(b, module); + auto call = b.create(api_func.getSymName(), + api_func.getResultTypes(), args); + + return call.getResult(0).cast>(); +} + +TypedValue getDotPrecision(XlaGpuApi &api, + ImplicitLocOpBuilder &b, + ModuleOp module, + lmhlo_gpu::GEMMOp op) { + SmallVector precision = llvm::to_vector( + llvm::map_range(op.getPrecisionConfigAttr(), [](Attribute attr) { + auto value = attr.cast().getValue(); + return static_cast(value); + })); + + SmallVector args = {getI64List(b, precision)}; + + auto api_func = api.getCreateDotPrecision(b, module); + auto call = b.create(api_func.getSymName(), + api_func.getResultTypes(), args); + + return call.getResult(0).cast>(); +} + +TypedValue getDotConfig(XlaGpuApi &api, ImplicitLocOpBuilder &b, + ModuleOp module, lmhlo_gpu::GEMMOp op) { + int32_t algorithm = op.getAlgorithm().value_or(-1); + + SmallVector args = {b.create(algorithm, 32), + b.create(op.getAlphaRealAttr()), + b.create(op.getAlphaImagAttr()), + b.create(op.getBetaAttr()), + getDotDimensionNumbers(api, b, module, op), + getDotPrecision(api, b, module, op)}; + + auto api_func = api.getCreateDotConfig(b, module); + auto call = b.create(api_func.getSymName(), + api_func.getResultTypes(), args); + + return call.getResult(0).cast>(); +} + +TypedValue getExecutionContext(Operation *op) { + auto func = op->getParentOfType(); + return func.getArguments().front().cast>(); +} + +struct ConvertGemmOp : public OpConversionPattern { + ConvertGemmOp(TypeConverter &converter, MLIRContext *ctx, + std::shared_ptr state, + std::shared_ptr api) + : OpConversionPattern(converter, ctx), + state(std::move(state)), + api(std::move(api)) {} + + LogicalResult matchAndRewrite( + lmhlo_gpu::GEMMOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *block = op->getBlock(); + auto module = op->getParentOfType(); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto dot_config = getDotConfig(*api, b, module, op); + + // Export arguments to buffer views. + auto lhs = state->remapped[block][op.getA()]; + auto rhs = state->remapped[block][op.getB()]; + auto out = state->remapped[block][op.getC()]; + + if (!lhs || !rhs || !out) { + return rewriter.notifyMatchFailure( + op, "missing memref to tensor mapping for lmhlo_gpu.gemm arguments"); + } + + // Arguments to a gemm dispatch. 
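+    //
+    // With template arguments spelled out, each export is expected to look
+    // like (sketch; assumes the `iree_input` TensorExportOp/BufferViewType
+    // names, matching the `iree_input.tensor.export` checked in the tests):
+    //
+    //   SmallVector<Value> args = {getExecutionContext(op)};
+    //   auto exported = b.create<IREE::Input::TensorExportOp>(
+    //       b.getType<IREE::Input::BufferViewType>(), src, ValueRange());
+    //   args.push_back(exported.getResult());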
+ SmallVector args = {getExecutionContext(op)}; + for (TypedValue src : {lhs, rhs, out}) { + auto export_op = b.create( + b.getType(), src, + /*source_dims=*/ValueRange()); + args.push_back(export_op.getResult()); + } + args.push_back(dot_config); + + // TODO(ezhulenev): Should we import buffer view back and update remapping? + auto api_func = api->getDispatchGemm(b, module); + b.create(api_func.getSymName(), api_func.getResultTypes(), + args); + + rewriter.eraseOp(op); + return success(); + } + + std::shared_ptr state; + std::shared_ptr api; +}; + +} // namespace + +//===----------------------------------------------------------------------===// + +void populateLibraryOpsConversionPatterns( + RewritePatternSet &patterns, TypeConverter &converter, + std::shared_ptr state) { + auto api = std::make_shared(); + auto *ctx = patterns.getContext(); + patterns.insert(converter, ctx, state, api); +} + +} // namespace xla::gpu diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h new file mode 100644 index 00000000000000..005eb685985b98 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_CONVERSION_CONVERT_LIBRARY_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_CONVERSION_CONVERT_LIBRARY_OPS_H_ + +#include + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/de_bufferization.h" + +namespace xla { +namespace gpu { + +// Appends patterns to convert `lmhlo_gpu` operations corresponding to library +// calls (cuBLAS, cuDNN, etc. operations). 
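+//
+// Typical wiring inside a conversion pass (sketch; `converter`, `target` and
+// the shared remapping `state` are assumed to be set up by the caller, as the
+// ConvertToOpenXla pass in this change does):
+//
+//   mlir::RewritePatternSet patterns(ctx);
+//   populateLibraryOpsConversionPatterns(patterns, converter, state);
+//   if (failed(applyPartialConversion(module, target, std::move(patterns))))
+//     return signalPassFailure();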
+void populateLibraryOpsConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, + std::shared_ptr state); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_CONVERSION_CONVERT_LIBRARY_OPS_H_ diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/ir/BUILD new file mode 100644 index 00000000000000..500442869f011a --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/BUILD @@ -0,0 +1,48 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_portable") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla:internal"], + licenses = ["notice"], +) + +gentbl_cc_library( + name = "xla_gpu_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "xla_gpu_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_gpu_dialect.cc.inc", + ), + ( + ["-gen-typedef-decls"], + "xla_gpu_types.h.inc", + ), + ( + ["-gen-typedef-defs"], + "xla_gpu_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_gpu_dialect.td", + deps = ["@llvm-project//mlir:OpBaseTdFiles"], +) + +cc_library( + name = "xla_gpu", + srcs = ["xla_gpu_dialect.cc"], + hdrs = ["xla_gpu_dialect.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":xla_gpu_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.cc b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.cc new file mode 100644 index 00000000000000..455b9c2d0ca89b --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.cc @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" + +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "mlir/IR/DialectImplementation.h" // from @llvm-project // IWYU pragma: keep + +//===----------------------------------------------------------------------===// +// XLA GPU Dialect +//===----------------------------------------------------------------------===// + +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.cc.inc" + +namespace xla::gpu { + +void XlaGpuDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_types.cc.inc" + >(); +} + +} // namespace xla::gpu + +#define GET_TYPEDEF_CLASSES +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_types.cc.inc" diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h new file mode 100644 index 00000000000000..bc351136517001 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_IR_XLA_GPU_DIALECT_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_IR_XLA_GPU_DIALECT_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // from @llvm-project // IWYU pragma: keep + +// XLA GPU dialect definition. +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_types.h.inc" + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_IR_XLA_GPU_DIALECT_H_ diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td new file mode 100644 index 00000000000000..f4042b2672b01a --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td @@ -0,0 +1,69 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef XLA_GPU_DIALECT +#else +#define XLA_GPU_DIALECT + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" + +//===----------------------------------------------------------------------===// +// XLA GPU Dialect +//===----------------------------------------------------------------------===// + +def XlaGpuDialect : Dialect { + let name = "xla_gpu"; + + let description = [{ + This dialect contains types required for lowering XLA:GPU programs to + OpenXLA runtime with all the library integrations. At run-time all MHLO + attributes (e.g. `#mhlo.dot` on `mhlo.dot_general` operation) are passed + around as reference-counted values (e.g. dot dimensions is just a struct). + }]; + + let cppNamespace = "::xla::gpu"; + let useDefaultTypePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// XLA GPU Types +//===----------------------------------------------------------------------===// + +class XLA_GPU_Type : + TypeDef { + let mnemonic = typeMnemonic; +} + +def ExecutionContextType : XLA_GPU_Type<"ExecutionContext", + "execution_context"> { + let summary = "XLA:GPU execution context"; +} + +def DotConfigType : XLA_GPU_Type<"DotConfig", "dot_config"> { + let summary = "Config for dot operation"; +} + +def DotDimensionNumbersType : XLA_GPU_Type<"DotDimensionNumbers", + "dot_dimension_numbers"> { + let summary = "Dimension numbers for dot operation"; +} + +def DotPrecisionType : XLA_GPU_Type<"DotPrecision", "dot_precision"> { + let summary = "Precision for dot operation"; +} + + +#endif // XLA_GPU_DIALECT diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD index 335fad7d0b6195..b796a530a0460d 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/BUILD @@ -49,8 +49,10 @@ package( # "@llvm-project//mlir:TensorDialect", # "@llvm-project//mlir:Transforms", # "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_compiled_ops", +# "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_library_ops", # "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_memref_ops", # "//tensorflow/compiler/xla/mlir/backends/openxla/conversion:convert_while_op", +# "//tensorflow/compiler/xla/mlir/backends/openxla/ir:xla_gpu", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", # ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc index 15220184f4d101..ed343ccf6ac690 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc @@ -32,8 +32,10 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h" +#include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_memref_ops.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_while_op.h" +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #define GEN_PASS_DECL_CONVERTTOOPENXLA @@ -74,6 +76,18 @@ IREE::Input::ExecutableSourceOp createXlaExecutableSource(ModuleOp module) { //===----------------------------------------------------------------------===// +// Adds `xla_gpu.execution_context` argument to all functions in the module. +static void addExecutionContextArgument(ModuleOp module) { + MLIRContext *ctx = module.getContext(); + + Type arg = ExecutionContextType::get(ctx); + DictionaryAttr attrs = DictionaryAttr::get(ctx); + + for (func::FuncOp func : module.getOps()) { + func.insertArguments({0}, {arg}, {attrs}, {func.getLoc()}); + } +} + class ConvertToOpenXlaPass : public ::impl::ConvertToOpenXlaBase { public: @@ -83,6 +97,9 @@ class ConvertToOpenXlaPass void runOnOperation() override { auto *ctx = &getContext(); + // Add execution context argument to all functions in the module. + addExecutionContextArgument(getOperation()); + // Add all pre-compiled XLA fusions to the module as an executable source. auto executable_source = createXlaExecutableSource(getOperation()); @@ -109,6 +126,7 @@ class ConvertToOpenXlaPass RewritePatternSet patterns(&getContext()); populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); + populateLibraryOpsConversionPatterns(patterns, converter, state); populateMemrefConversionPatterns(patterns, converter, state); populateWhileOpConversionPatterns(patterns, converter, state); populateCompiledOpsConversionPatterns( diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td index ecaa6d69c0be14..b7b33e34e29b3e 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td @@ -30,6 +30,7 @@ def ConvertToOpenXla : Pass<"xla-gpu-to-openxla", "mlir::ModuleOp"> { let dependentDialects = [ "mlir::iree_compiler::IREE::Input::IREEInputDialect", "mlir::scf::SCFDialect", + "xla::gpu::XlaGpuDialect", ]; } diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir index 48fba260e756af..450e193d294358 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir @@ -21,6 +21,7 @@ func.func @fusion( } // CHECK-LABEL: func @fusion( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<12xi8>, %[[ARG1:.*]]: tensor<12xi8>, // CHECK: %[[ARG2:.*]]: tensor<12xi8> {lmhlo.output_index = {{.*}}} // CHECK: ) { @@ -74,6 +75,7 @@ func.func @fusions( // with tied operands. 
// CHECK-LABEL: func @fusions( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<12xi8>, %[[ARG1:.*]]: tensor<12xi8>, // CHECK: %[[ARG2:.*]]: tensor<12xi8> {lmhlo.output_index = {{.*}}} // CHECK: ) { @@ -116,6 +118,7 @@ func.func @reinterpret_cast( // it when lowering fustions to dispatches. // CHECK-LABEL: func @reinterpret_cast( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<66560xi8> {lmhlo.output_index = {{.*}}} // CHECK: ) { // CHECK: %[[T:.*]] = iree_input.tensor.import {{.*}} tensor<1x4x128x65xbf16> diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir new file mode 100644 index 00000000000000..5475e64970e3d6 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir @@ -0,0 +1,44 @@ +// RUN: export MSAN_OPTIONS=intercept_strpbrk=0 +// RUN: xla-openxla-opt %s --xla-gpu-to-openxla --split-input-file \ +// RUN: | FileCheck %s + +func.func @gemm( + %arg0: memref<128xi8>, %arg1: memref<64xi8>, + %arg2: memref<32xi8> {lmhlo.output_index = dense<> : tensor<0xi64>} +) { + %c0 = arith.constant 0 : index + %view0 = memref.view %arg0[%c0][] : memref<128xi8> to memref<4x8xf32> + %view1 = memref.view %arg1[%c0][] : memref<64xi8> to memref<8x2xf32> + %view2 = memref.view %arg2[%c0][] : memref<32xi8> to memref<4x2xf32> + + "lmhlo_gpu.gemm"(%view0, %view1, %view2) { + alpha_imag = 0.000000e+00 : f64, + alpha_real = 1.000000e+00 : f64, + beta = 0.000000e+00 : f64, + dot_dimension_numbers = #mhlo.dot, + precision_config = [#mhlo, #mhlo] + } : (memref<4x8xf32>, memref<8x2xf32>, memref<4x2xf32>) -> () + + return +} + +// CHECK-LABEL: func @gemm( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, +// CHECK: %[[ARG0:.*]]: tensor<128xi8>, +// CHECK: %[[ARG1:.*]]: tensor<64xi8>, +// CHECK: %[[ARG2:.*]]: tensor<32xi8> {lmhlo.output_index = {{.*}}} +// CHECK: ) { +// CHECK: %[[LHS:.*]] = iree_input.tensor.import {{.*}} -> tensor<4x8xf32> +// CHECK: %[[RHS:.*]] = iree_input.tensor.import {{.*}} -> tensor<8x2xf32> +// CHECK: %[[OUT:.*]] = iree_input.tensor.import {{.*}} -> tensor<4x2xf32> +// CHECK: %[[DIMS:.*]] = call @xla_gpu.dot_dimension_numbers.create +// CHECK: %[[PRECISION:.*]] = call @xla_gpu.dot_precision.create +// CHECK: %[[CONFIG:.*]] = call @xla_gpu.dot_config.create +// CHECK: %[[LHS_BUF:.*]] = iree_input.tensor.export %[[LHS]] +// CHECK: %[[RHS_BUF:.*]] = iree_input.tensor.export %[[RHS]] +// CHECK: %[[OUT_BUF:.*]] = iree_input.tensor.export %[[OUT]] +// CHECK: call @xla_gpu.gemm.dispatch( +// CHECK: %[[CTX]], %[[LHS_BUF]], %[[RHS_BUF]], %[[OUT_BUF]], %[[CONFIG]] +// CHECK: ) +// CHECK: } diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/memref_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/memref_to_openxla.mlir index 4ba2e2a8a6464c..841c84b9ed66b5 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/memref_to_openxla.mlir +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/memref_to_openxla.mlir @@ -9,6 +9,7 @@ func.func @main(%arg0: memref<12xi8>) { } // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<12xi8> // CHECK: ) { // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -35,6 +36,7 @@ func.func @main(%arg0: memref<12xi8>) { } // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: 
!xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<12xi8> // CHECK: ) { // CHECK: %[[C8:.*]] = arith.constant 8 : index @@ -65,6 +67,7 @@ func.func @main(%arg0: memref<8xi8> {lmhlo.constant_name = "cst"}) { // with an argument itself. // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<8xi8> // CHECK: ) { // CHECK: %[[BUF:.*]] = iree_input.tensor.export %[[ARG0]] @@ -92,6 +95,7 @@ func.func @main(%arg0: memref<66560xi8>) { // either the buffer view itself, or as a separate metadata object. // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<66560xi8> // CHECK: ) { // CHECK: %[[C0:.*]] = arith.constant 0 : index diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/sort_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/sort_to_openxla.mlir index 6a675642806bfa..a3cba412e703df 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/sort_to_openxla.mlir +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/sort_to_openxla.mlir @@ -26,6 +26,7 @@ func.func @main(%arg0: memref<16xi8>, %arg1: memref<16xi8>, } // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<16xi8>, %[[ARG1:.*]]: tensor<16xi8>, // CHECK: %[[ARG2:.*]]: tensor<16xi8> // CHECK: ) { diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/while_loop_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/while_loop_to_openxla.mlir index 1e8bdb70a81c3f..1b0d887c0c30f1 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/while_loop_to_openxla.mlir +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/while_loop_to_openxla.mlir @@ -35,6 +35,7 @@ func.func @main(%arg0: memref<4xi8>, %arg1: memref<4xi8>, %arg2: memref<1xi8>) { // of both tensors as a result. // CHECK-LABEL: func @main( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, // CHECK: %[[ARG0:.*]]: tensor<4xi8>, %[[ARG1:.*]]: tensor<4xi8>, // CHECK: %[[ARG2:.*]]: tensor<1xi8> // CHECK: ) { diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc b/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc index 417cb72706fe74..843cd7a0b141dd 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" @@ -27,6 +28,11 @@ int main(int argc, char **argv) { registry.insert(); + + // General MLIR passes like `-cse` and `-canonicalize`. + registerTransformsPasses(); + + // Lowering to OpenXLA runtime. 
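+  // With both registered, the driver accepts pipelines like:
+  //   xla-openxla-opt input.mlir --xla-gpu-to-openxla --canonicalize --cse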
xla::gpu::registerOpenXlaPases(); return failed(MlirOptMain(argc, argv, "Xla OpenXLA Pass Driver\n", registry)); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index 630a13b02cb2dc..ccbf7abb102460 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "tf_platform_deps") +load("//tensorflow/compiler/xla/service/gpu/openxla:build_config.bzl", "if_not_openxla", "if_openxla") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -12,37 +13,24 @@ package_group( includes = ["//tensorflow/compiler/xla:friends"], ) +# Add `--define=xla_gpu_with_openxla_runtime=1` to build command to enable experimental OpenXLA/IREE +# backend for XLA:GPU executables. +config_setting( + name = "with_openxla_runtime", + values = { + "define": "xla_gpu_with_openxla_runtime=1", + }, +) + # copybara:uncomment_begin(not supported in OSS build) # -# # Add `--define=xla_gpu_with_openxla_runtime=1` to build command to enable experimental OpenXLA/IREE -# # backend for XLA:GPU executables. -# config_setting( -# name = "with_openxla_runtime", -# values = { -# "define": "xla_gpu_with_openxla_runtime=1", -# }, -# ) -# # cc_library( # name = "compiler", -# srcs = select({ -# ":with_openxla_runtime": ["compiler.cc"], -# "//conditions:default": [], -# }), -# hdrs = select({ -# ":with_openxla_runtime": ["compiler.h"], -# "//conditions:default": [], -# }), +# srcs = if_openxla(["compiler.cc"]), +# hdrs = if_openxla(["compiler.h"]), # # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not # # compatible with `non_prod` constraint. # compatible_with = [], -# data = select({ -# ":with_openxla_runtime": [ -# # IREE build is currently broken -# # "//third_party/iree/lib:libIREECompiler.so", -# ], -# "//conditions:default": [], -# }), # deps = [ # "@com_google_absl//absl/base", # "@llvm-project//llvm:Support", @@ -54,31 +42,19 @@ package_group( # ] + tf_platform_deps( # "compiler", # platform_dir = "//tensorflow/compiler/xla/service/gpu/openxla/", -# ) + select({ -# ":with_openxla_runtime": [ -# "//third_party/iree/compiler/bindings/c:headers", -# "//third_party/iree/compiler/bindings/c:loader", -# # IREE build is currently broken -# # "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", -# ], -# "//conditions:default": [], -# }), +# ) + if_openxla([ +# "//third_party/iree/compiler/bindings/c:headers", +# "//third_party/iree/compiler/bindings/c:loader", +# "//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect", +# ]), # ) # # cc_library( # name = "executable", -# srcs = select({ -# ":with_openxla_runtime": ["executable.cc"], -# "//conditions:default": [], -# }), +# srcs = if_openxla(["executable.cc"]), # hdrs = ["executable.h"], -# # TODO(ezhulenev): Override cc_library()'s default compatibility because IREE targets are not -# # compatible with `non_prod` constraint. 
# compatible_with = [], -# defines = select({ -# ":with_openxla_runtime": [], -# "//conditions:default": ["XLA_DISABLE_OPENXLA_RUNTIME=1"], -# }), +# defines = if_not_openxla(["XLA_DISABLE_OPENXLA_RUNTIME=1"]), # deps = [ # "@com_google_absl//absl/log", # "@com_google_absl//absl/log:check", @@ -92,19 +68,83 @@ package_group( # "//tensorflow/compiler/xla/service:executable", # "//tensorflow/compiler/xla/service/gpu:buffer_allocations", # "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", -# ] + select({ -# ":with_openxla_runtime": [ -# ":compiler", -# "//third_party/iree/runtime/src/iree/base", -# "//third_party/iree/runtime/src/iree/hal", -# "//third_party/iree/runtime/src/iree/hal/drivers/cuda", -# "//third_party/iree/runtime/src/iree/modules/hal", -# "//third_party/iree/runtime/src/iree/modules/hal:types", -# "//third_party/iree/runtime/src/iree/vm", -# "//third_party/iree/runtime/src/iree/vm/bytecode:module", -# ], -# "//conditions:default": [], -# }), +# ] + if_openxla([ +# ":compiler", +# ":module", +# ":vm", +# "//third_party/iree/runtime/src/iree/base", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/hal/drivers/cuda", +# "//third_party/iree/runtime/src/iree/modules/hal", +# "//third_party/iree/runtime/src/iree/modules/hal:types", +# "//third_party/iree/runtime/src/iree/vm", +# "//third_party/iree/runtime/src/iree/vm/bytecode:module", +# ]), +# ) +# +# cc_library( +# name = "gemm", +# srcs = if_openxla(["gemm.cc"]), +# hdrs = if_openxla(["gemm.h"]), +# compatible_with = [], +# deps = [ +# "@com_google_absl//absl/log", +# "//tensorflow/compiler/xla:status", +# "//tensorflow/compiler/xla:statusor", +# "//tensorflow/compiler/xla:xla_data_proto_cc", +# "//tensorflow/compiler/xla/service:executable", +# "//tensorflow/compiler/xla/service/gpu:matmul_utils", +# ] + if_openxla([ +# ":hal", +# ":vm", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/vm", +# ]), +# ) +# +# cc_library( +# name = "module", +# srcs = if_openxla(["module.cc"]), +# hdrs = if_openxla(["module.h"]), +# compatible_with = [], +# deps = [ +# "@com_google_absl//absl/log", +# ] + if_openxla([ +# ":gemm", +# ":hal", +# ":vm", +# "//third_party/iree/runtime/src/iree/base", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/modules/hal:types", +# "//third_party/iree/runtime/src/iree/vm", +# "//third_party/iree/runtime/src/iree/vm:cc", +# ]), +# ) +# +# cc_library( +# name = "hal", +# srcs = if_openxla(["hal.cc"]), +# hdrs = if_openxla(["hal.h"]), +# compatible_with = [], +# deps = [ +# "@com_google_absl//absl/container:inlined_vector", +# "@com_google_absl//absl/types:span", +# "//tensorflow/compiler/xla:shape_util", +# "//tensorflow/compiler/xla/stream_executor:device_memory", +# ] + if_openxla([ +# "//third_party/iree/runtime/src/iree/hal", +# ]), +# ) +# +# cc_library( +# name = "vm", +# srcs = if_openxla(["vm.cc"]), +# hdrs = if_openxla(["vm.h"]), +# compatible_with = [], +# deps = if_openxla([ +# "//third_party/iree/runtime/src/iree/vm", +# "//third_party/iree/runtime/src/iree/vm:cc", +# ]), # ) # # copybara:uncomment_end_and_comment_begin diff --git a/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl b/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl new file mode 100644 index 00000000000000..1599be4c4374c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl @@ -0,0 +1,13 @@ +"""Helpers for conditional OpenXLA compilation.""" + +def if_openxla(then, 
otherwise = []): + return select({ + ":with_openxla_runtime": then, + "//conditions:default": otherwise, + }) + +def if_not_openxla(then, otherwise = []): + return select({ + ":with_openxla_runtime": otherwise, + "//conditions:default": then, + }) diff --git a/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc index 74df74631a9a9b..1077534141b66f 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc @@ -183,7 +183,7 @@ using namespace mlir::iree_compiler; // NOLINT // TODO(ezhulenev): Query compute capability from the XLA module and set it up // at the module level. -static constexpr int kComputeCapability = 60; +static constexpr int kComputeCapability = 35; static IREE::Input::ExecutableTargetAttr getExecutableTarget(MLIRContext* ctx) { Builder b(ctx); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.cc b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc index f96a63c347c13a..d242bc969e12e5 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/openxla/compiler.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/module.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/status.h" @@ -132,6 +134,12 @@ OpenXlaRuntimeExecutable::Create(std::unique_ptr program, IREE_HAL_MODULE_FLAG_NONE, allocator, &modules->emplace_back())); + // Load XLA:GPU module. + IREE_CHECK_OK(RegisterXlaGpuTypes(instance.get())); + IREE_CHECK_OK(CreateXlaGpuModule(instance.get(), allocator, + iree_hal_device_allocator(device->device), + &modules->emplace_back())); + // Load module compiled from XLA program to a VM flatbuffer. IREE_CHECK_OK(iree_vm_bytecode_module_create( instance.get(), @@ -197,15 +205,22 @@ Status OpenXlaRuntimeExecutable::Execute( iree_allocator_t allocator = iree_allocator_system(); - // Convert XLA buffer allocations to IREE buffer views. + // Prepare a list for passing arguments to the function. iree::vm::ref inputs; IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), buffer_allocations.size(), allocator, &inputs)); + // Add execution context as the first arguments. + auto execution_context = + iree::vm::make_ref(run_options, &debug_options_); + // TODO(ezhulenev): Can we do ref_move here? + IREE_CHECK_OK(iree_vm_list_push_ref_retain(inputs.get(), execution_context)); + // Import argument buffers as device-local IREE buffers. std::vector> buffers; + // Convert XLA buffer allocations to IREE buffer views. for (unsigned i = 0; i < num_buffer_allocations; ++i) { // Import XLA buffer as an IREE external buffer. 
iree_hal_external_buffer_t external_buffer; diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.h b/tensorflow/compiler/xla/service/gpu/openxla/executable.h index 33ad00401adfbe..d34d9760a42f1a 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/executable.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.h @@ -66,7 +66,7 @@ struct OpenXlaRuntimeExecutable { #include #include -#include "third_party/iree/runtime/src/iree/vm/api.h" +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/openxla/compiler.h" diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc new file mode 100644 index 00000000000000..ad7b64e824fa35 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc @@ -0,0 +1,148 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/gemm.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/hal.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla::gpu { + +//===-----------------------------------------------------------------------===/ +// XLA:GPU gemm API +//===-----------------------------------------------------------------------===/ + +static StatusOr GetGemmConfig(const iree_hal_buffer_view_t* lhs, + const iree_hal_buffer_view_t* rhs, + const iree_hal_buffer_view_t* out, + const vm::DotConfig& config) { + int64_t compute_precision = + config.dot_precision->precision.empty() + ? 
se::blas::kDefaultComputePrecision + : *absl::c_max_element(config.dot_precision->precision); + + return GemmConfig::For( + GetShape(lhs), config.dot_dimension_numbers->lhs_batch_dims, + config.dot_dimension_numbers->lhs_contracting_dims, // lhs + GetShape(rhs), config.dot_dimension_numbers->rhs_batch_dims, + config.dot_dimension_numbers->rhs_contracting_dims, // rhs + GetShape(out), // out + config.alpha_real, config.alpha_imag, config.beta, config.algorithm, + compute_precision); +} + +Status DispatchGemm(const vm::ExecutionContext& ctx, + iree_hal_allocator_t* device_allocator, + iree_hal_buffer_view_t* lhs, iree_hal_buffer_view_t* rhs, + iree_hal_buffer_view_t* out, const vm::DotConfig& config) { + se::Stream* stream = ctx.run_options->stream(); + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_data, + GetDeviceMemory(device_allocator, lhs)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_data, + GetDeviceMemory(device_allocator, rhs)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_data, + GetDeviceMemory(device_allocator, out)); + + bool deterministic = ctx.debug_options->xla_gpu_deterministic_ops(); + + TF_ASSIGN_OR_RETURN(auto gemm_config, GetGemmConfig(lhs, rhs, out, config)); + return RunGemm(gemm_config, lhs_data, rhs_data, output_data, deterministic, + stream); +} + +//===-----------------------------------------------------------------------===/ +// XLA:GPU gemm custom module API +//===-----------------------------------------------------------------------===/ + +// TODO(ezhulenev): We need to find a way to pass original Status back to the +// caller preserving the location and stack frame. Can we use some diagnostic +// side channel via the ExecutionContext? +static iree::Status FromStatus(Status status) { + if (status.ok()) return iree_ok_status(); + + // TODO(ezhulenev): Convert from ABSL to IREE error code. 
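+  //
+  // One possible shape for that conversion (sketch; assumes `status.code()`
+  // yields an absl::StatusCode that maps onto the matching IREE codes):
+  //
+  //   switch (static_cast<absl::StatusCode>(status.code())) {
+  //     case absl::StatusCode::kInvalidArgument:
+  //       return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "%s",
+  //                               status.ToString().c_str());
+  //     case absl::StatusCode::kNotFound:
+  //       return iree_make_status(IREE_STATUS_NOT_FOUND, "%s",
+  //                               status.ToString().c_str());
+  //     default:
+  //       return iree_make_status(IREE_STATUS_INTERNAL, "%s",
+  //                               status.ToString().c_str());
+  //   }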
+ std::string err = status.ToString(); + return iree_make_status(IREE_STATUS_INTERNAL, "internal error: %s", + err.c_str()); +} + +GemmAPI::GemmAPI(iree_hal_allocator_t* device_allocator) + : device_allocator_(device_allocator) {} + +iree::StatusOr> +GemmAPI::DotDimensionNumbersCreate( + iree::vm::ref lhs_batching_dims, + iree::vm::ref rhs_batching_dims, + iree::vm::ref lhs_contracting_dims, + iree::vm::ref rhs_contracting_dims) { + auto ref = iree::vm::make_ref(); + + IREE_ASSIGN_OR_RETURN(ref->lhs_batch_dims, + vm::GetI64Vector(lhs_batching_dims.get())); + IREE_ASSIGN_OR_RETURN(ref->rhs_batch_dims, + vm::GetI64Vector(rhs_batching_dims.get())); + IREE_ASSIGN_OR_RETURN(ref->lhs_contracting_dims, + vm::GetI64Vector(lhs_contracting_dims.get())); + IREE_ASSIGN_OR_RETURN(ref->rhs_contracting_dims, + vm::GetI64Vector(rhs_contracting_dims.get())); + + return ref; +} + +iree::StatusOr> GemmAPI::DotPrecisionCreate( + iree::vm::ref precision) { + auto ref = iree::vm::make_ref(); + IREE_ASSIGN_OR_RETURN(ref->precision, vm::GetI64Vector(precision.get())); + return ref; +} + +iree::StatusOr> GemmAPI::DotConfigCreate( + int32_t algorithm, float alpha_real, float alpha_imag, float beta, + iree::vm::ref dot_dimension_numbers, + iree::vm::ref dot_precision) { + auto ref = iree::vm::make_ref(); + ref->algorithm = algorithm; + ref->alpha_real = alpha_real; + ref->alpha_imag = alpha_imag; + ref->beta = beta; + ref->dot_dimension_numbers = std::move(dot_dimension_numbers); + ref->dot_precision = std::move(dot_precision); + return ref; +} + +iree::Status GemmAPI::GemmDispatch(iree::vm::ref ctx, + iree::vm::ref lhs, + iree::vm::ref rhs, + iree::vm::ref out, + iree::vm::ref config) { + return FromStatus(DispatchGemm(*ctx, device_allocator_, lhs.get(), rhs.get(), + out.get(), *config)); +} +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DEFINE_TYPE_ADAPTERS(dot_config, xla::gpu::vm::DotConfig); +IREE_VM_DEFINE_TYPE_ADAPTERS(dot_dimension_numbers, + xla::gpu::vm::DotDimensionNumbers); +IREE_VM_DEFINE_TYPE_ADAPTERS(dot_precision, xla::gpu::vm::DotPrecision); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.h b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h new file mode 100644 index 00000000000000..79dbd3fada167a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h @@ -0,0 +1,113 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_GEMM_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_GEMM_H_ + +#include +#include + +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla::gpu { + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module module types +//===-----------------------------------------------------------------------===/ + +namespace vm { + +struct DotDimensionNumbers : public iree::vm::RefObject { + std::vector lhs_batch_dims; + std::vector rhs_batch_dims; + std::vector lhs_contracting_dims; + std::vector rhs_contracting_dims; +}; + +struct DotPrecision : public iree::vm::RefObject { + std::vector precision; +}; + +struct DotConfig : public iree::vm::RefObject { + double alpha_real; + double alpha_imag; + double beta; + int32_t algorithm; + iree::vm::ref dot_dimension_numbers; + iree::vm::ref dot_precision; +}; + +} // namespace vm + +//===-----------------------------------------------------------------------===/ +// XLA:GPU gemm API +//===-----------------------------------------------------------------------===/ + +Status DispatchGemm(const vm::ExecutionContext& ctx, + iree_hal_allocator_t* device_allocator, + iree_hal_buffer_view_t* lhs, iree_hal_buffer_view_t* rhs, + iree_hal_buffer_view_t* out, const vm::DotConfig& config); + +//===-----------------------------------------------------------------------===/ +// XLA:GPU gemm custom module API +//===-----------------------------------------------------------------------===/ + +class GemmAPI { + public: + explicit GemmAPI(iree_hal_allocator_t* device_allocator); + + // Creates `xla_gpu.dot_dimension_numbers` value. + iree::StatusOr> + DotDimensionNumbersCreate(iree::vm::ref lhs_batching_dims, + iree::vm::ref rhs_batching_dims, + iree::vm::ref lhs_contracting_dims, + iree::vm::ref rhs_contracting_dims); + + // Creates `xla_gpu.dot_precision` value. + iree::StatusOr> DotPrecisionCreate( + iree::vm::ref precision); + + // Creates `xla_gpu.dot_config` value. + iree::StatusOr> DotConfigCreate( + int32_t algorithm, float alpha_real, float alpha_imag, float beta, + iree::vm::ref dot_dimension_numbers, + iree::vm::ref dot_precision); + + // Dispatches gemm operation with given buffers and config. 
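+  // Compiled modules reach this through the exported `xla_gpu.gemm.dispatch`
+  // function; see the lowering checked in `gemm_to_openxla.mlir`:
+  //   call @xla_gpu.gemm.dispatch(%ctx, %lhs_buf, %rhs_buf, %out_buf, %config)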
+ iree::Status GemmDispatch(iree::vm::ref ctx, + iree::vm::ref lhs, + iree::vm::ref rhs, + iree::vm::ref out, + iree::vm::ref config); + + private: + iree_hal_allocator_t* device_allocator_; +}; + +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DECLARE_TYPE_ADAPTERS(dot_config, xla::gpu::vm::DotConfig); +IREE_VM_DECLARE_TYPE_ADAPTERS(dot_dimension_numbers, + xla::gpu::vm::DotDimensionNumbers); +IREE_VM_DECLARE_TYPE_ADAPTERS(dot_precision, xla::gpu::vm::DotPrecision); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_GEMM_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/hal.cc b/tensorflow/compiler/xla/service/gpu/openxla/hal.cc new file mode 100644 index 00000000000000..173c52795c5579 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/hal.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/hal.h" + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// Helper functions to work with IREE buffers +//===----------------------------------------------------------------------===// + +// TODO(ezhulenev): Add support for all element types. +static PrimitiveType GetElementType(const iree_hal_buffer_view_t* view) { + switch (iree_hal_buffer_view_element_type(view)) { + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + return PrimitiveType::F32; + default: + assert(false && "unsupported iree element type"); + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } +} + +static absl::InlinedVector GetDims( + const iree_hal_buffer_view_t* view) { + const iree_host_size_t* dims = iree_hal_buffer_view_shape_dims(view); + iree_host_size_t rank = iree_hal_buffer_view_shape_rank(view); + return absl::InlinedVector(dims, dims + rank); +} + +Shape GetShape(const iree_hal_buffer_view_t* view) { + return ShapeUtil::MakeShape(GetElementType(view), GetDims(view)); +} + +StatusOr GetDeviceMemory( + iree_hal_allocator_t* device_allocator, iree_hal_buffer_t* buffer) { + // Get original allocation behind a buffer subspan. + iree_hal_buffer_t* allocated = iree_hal_buffer_allocated_buffer(buffer); + + // Export original allocation as a device allocation. 
+ iree_hal_external_buffer_t external_buffer; + iree::Status exported = iree_hal_allocator_export_buffer( + device_allocator, allocated, + IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION, + IREE_HAL_EXTERNAL_BUFFER_FLAG_NONE, &external_buffer); + + if (!exported.ok()) + return absl::InternalError(absl::StrFormat( + "failed to export HAL buffer: %s", exported.ToString())); + + auto* data = reinterpret_cast( + external_buffer.handle.device_allocation.ptr); + return se::DeviceMemoryBase(data + iree_hal_buffer_byte_offset(buffer)); +} + +StatusOr GetDeviceMemory( + iree_hal_allocator_t* device_allocator, iree_hal_buffer_view_t* view) { + return GetDeviceMemory(device_allocator, iree_hal_buffer_view_buffer(view)); +} + +} // namespace xla::gpu diff --git a/tensorflow/compiler/xla/service/gpu/openxla/hal.h b/tensorflow/compiler/xla/service/gpu/openxla/hal.h new file mode 100644 index 00000000000000..ea33958cc01ea5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/hal.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_HAL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_HAL_H_ + +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" + +namespace xla::gpu { + +//===----------------------------------------------------------------------===// +// Helper functions to work with IREE buffers and buffer views +//===----------------------------------------------------------------------===// + +Shape GetShape(const iree_hal_buffer_view_t* view); + +StatusOr GetDeviceMemory( + iree_hal_allocator_t* device_allocator, iree_hal_buffer_t* buffer); +StatusOr GetDeviceMemory( + iree_hal_allocator_t* device_allocator, iree_hal_buffer_view_t* view); + +} // namespace xla::gpu + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_HAL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/module.cc b/tensorflow/compiler/xla/service/gpu/openxla/module.cc new file mode 100644 index 00000000000000..554df91150c7dc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/module.cc @@ -0,0 +1,161 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/module.h" + +#include + +#include "third_party/iree/runtime/src/iree/base/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/modules/hal/types.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/native_module_cc.h" +#include "third_party/iree/runtime/src/iree/vm/native_module_packing.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/gemm.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" + +namespace xla::gpu { + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module state +//===-----------------------------------------------------------------------===/ + +class XlaGpuModuleState : public GemmAPI { + public: + explicit XlaGpuModuleState(iree_hal_allocator_t* device_allocator) + : GemmAPI(device_allocator) {} +}; + +//===----------------------------------------------------------------------===// +// Helper functions for exporting native functions to a module +//===----------------------------------------------------------------------===// + +// Casts a pointer to a member function of a base class to a pointer to a member +// function of the parent class. ("Pointer to member conversions" [conv.mem]). +template +static constexpr auto UpCast(iree::Status (Base::*fn)(Params...)) { + return static_cast(fn); +} + +template +static constexpr auto UpCast(iree::StatusOr (Base::*fn)(Params...)) { + return static_cast (XlaGpuModuleState::*)(Params...)>(fn); +} + +template +static constexpr auto MakeApiFunction(const char* name, + iree::Status (Base::*fn)(Params...)) { + return iree::vm::MakeNativeFunction(name, UpCast(fn)); +} + +template +static constexpr auto MakeApiFunction( + const char* name, iree::StatusOr (Base::*fn)(Params...)) { + return iree::vm::MakeNativeFunction(name, UpCast(fn)); +} + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module +//===-----------------------------------------------------------------------===/ + +static const iree::vm::NativeFunction kXlaGpuFunctions[] = { + // XLA:GPU Gemm APIs + MakeApiFunction("dot_dimension_numbers.create", + &GemmAPI::DotDimensionNumbersCreate), + MakeApiFunction("dot_precision.create", &GemmAPI::DotPrecisionCreate), + MakeApiFunction("dot_config.create", &GemmAPI::DotConfigCreate), + MakeApiFunction("gemm.dispatch", &GemmAPI::GemmDispatch), +}; + +class XlaGpuModule : public iree::vm::NativeModule { + using NativeModule = iree::vm::NativeModule; + + public: + XlaGpuModule(iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator); + + iree::StatusOr> CreateState( + iree_allocator_t host_allocator) override; + + private: + static constexpr uint32_t kVersion = 0; + iree_hal_allocator_t* device_allocator_; +}; + +XlaGpuModule::XlaGpuModule(iree_vm_instance_t* instance, + iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator) + : NativeModule("xla_gpu", kVersion, instance, host_allocator, + kXlaGpuFunctions), + device_allocator_(device_allocator) {} + +iree::StatusOr> XlaGpuModule::CreateState( + iree_allocator_t host_allocator) { + return std::make_unique(device_allocator_); +} + 
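+// Note: functions above are exported under the module name ("xla_gpu"), so a
+// loaded context resolves them by the same fully qualified names used by the
+// compiler-side declarations, e.g. (sketch, assuming the usual IREE VM C API):
+//
+//   iree_vm_function_t fn;
+//   IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function(
+//       context, iree_make_cstring_view("xla_gpu.gemm.dispatch"), &fn));
+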
+//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module constructor +//===-----------------------------------------------------------------------===/ + +iree_status_t CreateXlaGpuModule(iree_vm_instance_t* instance, + iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator, + iree_vm_module_t** out_module) { + IREE_ASSERT_ARGUMENT(out_module); + + RegisterXlaGpuTypes(instance); + + auto module = std::make_unique(instance, host_allocator, + device_allocator); + *out_module = module.release()->interface(); + + return iree_ok_status(); +} + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module type registration +//===-----------------------------------------------------------------------===/ + +template +static iree_status_t RegisterType(iree_vm_instance_t* instance, + const char* name, iree_vm_ref_type_t* out) { + static iree_vm_ref_type_descriptor_t descriptor = {nullptr}; + + descriptor.type_name = iree_make_cstring_view(name); + descriptor.offsetof_counter = T::offsetof_counter(); + descriptor.destroy = T::DirectDestroy; + + return iree_vm_instance_register_type(instance, &descriptor, out); +} + +iree_status_t RegisterXlaGpuTypes(iree_vm_instance_t* instance) { + // XLA:GPU Execution context type + IREE_RETURN_IF_ERROR(RegisterType( + instance, "xla_gpu.execution_context", &execution_context_registration)); + + // XLA:GPU Gemm types + IREE_RETURN_IF_ERROR(RegisterType( + instance, "xla_gpu.dot_dimension_numbers", + &dot_dimension_numbers_registration)); + IREE_RETURN_IF_ERROR(RegisterType( + instance, "xla_gpu.dot_precision", &dot_precision_registration)); + IREE_RETURN_IF_ERROR(RegisterType( + instance, "xla_gpu.dot_config", &dot_config_registration)); + + return iree_ok_status(); +} + +} // namespace xla::gpu diff --git a/tensorflow/compiler/xla/service/gpu/openxla/module.h b/tensorflow/compiler/xla/service/gpu/openxla/module.h new file mode 100644 index 00000000000000..db440c214eedce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/module.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_MODULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_MODULE_H_ + +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep + +namespace xla::gpu { + +// Creates XLA:GPU custom module implementing StreamExecutor integration. +iree_status_t CreateXlaGpuModule(iree_vm_instance_t* instance, + iree_allocator_t host_allocator, + iree_hal_allocator_t* device_allocator, + iree_vm_module_t** out_module); + +// Register XLA:GPU custom module types with the IREE VM. 
+iree_status_t RegisterXlaGpuTypes(iree_vm_instance_t* instance); + +} // namespace xla::gpu + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_MODULE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.cc b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc new file mode 100644 index 00000000000000..a295b418ea3158 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" + +#include + +namespace xla::gpu::vm { + +using iree::StatusOr; + +//===----------------------------------------------------------------------===// +// Helper functions to work with VM lists +//===----------------------------------------------------------------------===// + +StatusOr> GetI64Vector(const iree_vm_list_t* list) { + iree_host_size_t size = iree_vm_list_size(list); + std::vector values(size); + for (iree_host_size_t i = 0; i < size; ++i) { + iree_vm_value_t value; + IREE_RETURN_IF_ERROR( + iree_vm_list_get_value_as(list, i, IREE_VM_VALUE_TYPE_I64, &value)); + values[i] = value.i64; + } + return values; +} + +} // namespace xla::gpu::vm + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DEFINE_TYPE_ADAPTERS(execution_context, xla::gpu::vm::ExecutionContext); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.h b/tensorflow/compiler/xla/service/gpu/openxla/vm.h new file mode 100644 index 00000000000000..d8cccd347bc01a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.h @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ + +#include + +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep + +namespace xla { + +class DebugOptions; +class ServiceExecutableRunOptions; + +namespace gpu::vm { + +//===----------------------------------------------------------------------===// +// Execution context of a single XLA invocation +//===----------------------------------------------------------------------===// + +// We use XLA:GPU execution context to pass XLA:GPU invocation details to all +// runtime APIs. For example through `run_options` pointer we get access to +// the current compute stream, stream borrower, parent executor, etc. +struct ExecutionContext : public iree::vm::RefObject { + ExecutionContext(const ServiceExecutableRunOptions* run_options, + const DebugOptions* debug_options) + : run_options(run_options), debug_options(debug_options) {} + + const ServiceExecutableRunOptions* run_options; + const DebugOptions* debug_options; +}; + +//===----------------------------------------------------------------------===// +// Helper functions to work with VM lists +//===----------------------------------------------------------------------===// + +iree::StatusOr> GetI64Vector(const iree_vm_list_t* list); + +} // namespace gpu::vm +} // namespace xla + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DECLARE_TYPE_ADAPTERS(execution_context, + xla::gpu::vm::ExecutionContext); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ From 5bb35d2fee771d0f7ede8086d9c21bf499d674c2 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 28 Jul 2023 00:11:24 -0700 Subject: [PATCH 296/410] [XLA:GPU] Bubble up triton autotuner error messages. PiperOrigin-RevId: 551758009 --- .../compiler/xla/service/gpu/triton_autotuner.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index eabd445a72240c..e0587af9197537 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -258,8 +258,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { if (!rz_check_status.ok()) { LOG(ERROR) << "Red zone modified"; res.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); - *res.mutable_failure()->mutable_msg() = - rz_check_status.RedzoneFailureMsg(); + res.mutable_failure()->set_msg(rz_check_status.RedzoneFailureMsg()); CHECK(!config_.should_crash_on_check_failure()); continue; } @@ -270,12 +269,15 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { stream, /*current=*/profiling_output->output.root_buffer(), /*expected=*/reference_buffer->root_buffer())); if (!outputs_match) { - LOG(ERROR) << "Results do not match the reference. " - << "This is likely a bug/unexpected loss of precision."; + const char kMessage[] = + "Results do not match the reference. This is likely a " + "bug/unexpected loss of precision."; + LOG(ERROR) << kMessage; CHECK(!config_.should_crash_on_check_failure()); // WRONG_RESULT is not taken seriously by PickBestResult(), so // use DISQUALIFIED. 
res.mutable_failure()->set_kind(AutotuneResult::DISQUALIFIED); + res.mutable_failure()->set_msg(kMessage); } } results.push_back(res); From 158426ab2b6f53d060646d13b88fc3891a098d0c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 28 Jul 2023 00:46:40 -0700 Subject: [PATCH 297/410] [xla:gpu] Add HLO tracing support Extract original HLO operation names from an MLIR operation location and pass it to runtime for annotating library calls with ScopedAnnotation. This only works for "library calls" operations (e.g. cuBLAS/gemm), and we need deeper integration with IREE to support tracing for kernel launches. PiperOrigin-RevId: 551766108 --- .../mlir/backends/openxla/conversion/BUILD | 1 + .../openxla/conversion/convert_library_ops.cc | 37 ++++++++++++++++++- .../backends/openxla/ir/xla_gpu_dialect.td | 3 ++ .../transforms/tests/gemm_to_openxla.mlir | 9 ++++- .../compiler/xla/service/gpu/openxla/BUILD | 3 +- .../compiler/xla/service/gpu/openxla/gemm.cc | 14 ++++++- .../compiler/xla/service/gpu/openxla/gemm.h | 8 +++- .../xla/service/gpu/openxla/module.cc | 12 +++++- .../compiler/xla/service/gpu/openxla/vm.cc | 21 ++++++++++- .../compiler/xla/service/gpu/openxla/vm.h | 18 +++++++++ 10 files changed, 114 insertions(+), 12 deletions(-) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD index 589f85c39b3750..5a858ae8c8d232 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD @@ -63,6 +63,7 @@ package( # "@llvm-project//mlir:Transforms", # "//tensorflow/compiler/xla/mlir/backends/openxla/ir:xla_gpu", # "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", +# "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", # ] + if_openxla(["//third_party/iree/llvm-external-projects/iree-dialects:IREEInputDialect"]), # ) # diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc index 9118f6c9d1039b..1f17b3c1ce71e2 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" namespace xla::gpu { @@ -67,6 +68,9 @@ class XlaGpuApi { // Imports `@xla_gpu.gemm.dispatch` into the module. func::FuncOp getDispatchGemm(OpBuilder &b, ModuleOp module); + // Imports `@xla_gpu.trace.create` into the module. 
+ func::FuncOp getCreateTrace(OpBuilder &b, ModuleOp module); + private: SymbolTable &symTable(ModuleOp module); @@ -114,12 +118,23 @@ func::FuncOp XlaGpuApi::getCreateDotConfig(OpBuilder &b, ModuleOp module) { func::FuncOp XlaGpuApi::getDispatchGemm(OpBuilder &b, ModuleOp module) { auto execution_context = b.getType(); auto buffer_view = b.getType(); - SmallVector args = {execution_context, buffer_view, buffer_view, - buffer_view, b.getType()}; + SmallVector args = {execution_context, + buffer_view, // lhs + buffer_view, // rhs + buffer_view, // out + b.getType(), + b.getType()}; return addDecl(b, module, "xla_gpu.gemm.dispatch", FunctionType::get(b.getContext(), args, /*rets=*/TypeRange())); } +func::FuncOp XlaGpuApi::getCreateTrace(OpBuilder &b, ModuleOp module) { + SmallVector args = {b.getType()}; + SmallVector rets = {b.getType()}; + return addDecl(b, module, "xla_gpu.trace.create", + FunctionType::get(b.getContext(), args, rets)); +} + SymbolTable &XlaGpuApi::symTable(ModuleOp module) { return sym_table_.getSymbolTable(module); } @@ -221,6 +236,22 @@ TypedValue getDotConfig(XlaGpuApi &api, ImplicitLocOpBuilder &b, return call.getResult(0).cast>(); } +TypedValue getTrace(XlaGpuApi &api, ImplicitLocOpBuilder &b, + ModuleOp module, lmhlo_gpu::GEMMOp op) { + // Get original HLO operation name from the location. + Value hlo_op = b.create( + b.getType(), + /*name=*/b.getStringAttr("hlo_op"), + /*value=*/mhlo::GetDebugNameFromLocation(op.getLoc()), + /*alignment=*/nullptr, /*mime_type=*/nullptr); + + auto api_func = api.getCreateTrace(b, module); + auto call = b.create( + api_func.getSymName(), api_func.getResultTypes(), ValueRange(hlo_op)); + + return call.getResult(0).cast>(); +} + TypedValue getExecutionContext(Operation *op) { auto func = op->getParentOfType(); return func.getArguments().front().cast>(); @@ -243,6 +274,7 @@ struct ConvertGemmOp : public OpConversionPattern { ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto dot_config = getDotConfig(*api, b, module, op); + auto trace = getTrace(*api, b, module, op); // Export arguments to buffer views. auto lhs = state->remapped[block][op.getA()]; @@ -263,6 +295,7 @@ struct ConvertGemmOp : public OpConversionPattern { args.push_back(export_op.getResult()); } args.push_back(dot_config); + args.push_back(trace); // TODO(ezhulenev): Should we import buffer view back and update remapping? 
auto api_func = api->getDispatchGemm(b, module); diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td index f4042b2672b01a..bb75d411877183 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td @@ -65,5 +65,8 @@ def DotPrecisionType : XLA_GPU_Type<"DotPrecision", "dot_precision"> { let summary = "Precision for dot operation"; } +def TraceType : XLA_GPU_Type<"Trace", "trace"> { + let summary = "XLA:GPU trace annotation"; +} #endif // XLA_GPU_DIALECT diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir index 5475e64970e3d6..aa0c7f35fab5b3 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/gemm_to_openxla.mlir @@ -2,6 +2,8 @@ // RUN: xla-openxla-opt %s --xla-gpu-to-openxla --split-input-file \ // RUN: | FileCheck %s +#loc = loc("custom-call") + func.func @gemm( %arg0: memref<128xi8>, %arg1: memref<64xi8>, %arg2: memref<32xi8> {lmhlo.output_index = dense<> : tensor<0xi64>} @@ -18,7 +20,7 @@ func.func @gemm( dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo] - } : (memref<4x8xf32>, memref<8x2xf32>, memref<4x2xf32>) -> () + } : (memref<4x8xf32>, memref<8x2xf32>, memref<4x2xf32>) -> () loc(#loc) return } @@ -35,10 +37,13 @@ func.func @gemm( // CHECK: %[[DIMS:.*]] = call @xla_gpu.dot_dimension_numbers.create // CHECK: %[[PRECISION:.*]] = call @xla_gpu.dot_precision.create // CHECK: %[[CONFIG:.*]] = call @xla_gpu.dot_config.create +// CHECK: %[[HLO:.*]] = iree_input.byte_buffer.constant {{.*}} = "custom-call" +// CHECK: %[[TRACE:.*]] = call @xla_gpu.trace.create(%[[HLO]]) // CHECK: %[[LHS_BUF:.*]] = iree_input.tensor.export %[[LHS]] // CHECK: %[[RHS_BUF:.*]] = iree_input.tensor.export %[[RHS]] // CHECK: %[[OUT_BUF:.*]] = iree_input.tensor.export %[[OUT]] // CHECK: call @xla_gpu.gemm.dispatch( -// CHECK: %[[CTX]], %[[LHS_BUF]], %[[RHS_BUF]], %[[OUT_BUF]], %[[CONFIG]] +// CHECK: %[[CTX]], %[[LHS_BUF]], %[[RHS_BUF]], %[[OUT_BUF]], +// CHECK: %[[CONFIG]], %[[TRACE]] // CHECK: ) // CHECK: } diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index ccbf7abb102460..aa7ef22c441f6a 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -94,6 +94,7 @@ config_setting( # "//tensorflow/compiler/xla:xla_data_proto_cc", # "//tensorflow/compiler/xla/service:executable", # "//tensorflow/compiler/xla/service/gpu:matmul_utils", +# "//tensorflow/tsl/profiler/lib:scoped_annotation", # ] + if_openxla([ # ":hal", # ":vm", @@ -141,7 +142,7 @@ config_setting( # srcs = if_openxla(["vm.cc"]), # hdrs = if_openxla(["vm.h"]), # compatible_with = [], -# deps = if_openxla([ +# deps = ["@com_google_absl//absl/strings:str_format"] + if_openxla([ # "//third_party/iree/runtime/src/iree/vm", # "//third_party/iree/runtime/src/iree/vm:cc", # ]), diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc index ad7b64e824fa35..f40c00e1f0487f 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc @@ -20,11 +20,15 @@ limitations under 
the License. #include "tensorflow/compiler/xla/service/gpu/matmul_utils.h" #include "tensorflow/compiler/xla/service/gpu/openxla/hal.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { +using tsl::profiler::ScopedAnnotation; + //===-----------------------------------------------------------------------===/ // XLA:GPU gemm API //===-----------------------------------------------------------------------===/ @@ -72,6 +76,8 @@ Status DispatchGemm(const vm::ExecutionContext& ctx, // XLA:GPU gemm custom module API //===-----------------------------------------------------------------------===/ +namespace vm { + // TODO(ezhulenev): We need to find a way to pass original Status back to the // caller preserving the location and stack frame. Can we use some diagnostic // side channel via the ExecutionContext? @@ -128,14 +134,18 @@ iree::StatusOr> GemmAPI::DotConfigCreate( return ref; } -iree::Status GemmAPI::GemmDispatch(iree::vm::ref ctx, +iree::Status GemmAPI::GemmDispatch(iree::vm::ref ctx, iree::vm::ref lhs, iree::vm::ref rhs, iree::vm::ref out, - iree::vm::ref config) { + iree::vm::ref config, + iree::vm::ref trace) { + ScopedAnnotation annotation([&] { return ToScopedAnnotationName(*trace); }); return FromStatus(DispatchGemm(*ctx, device_allocator_, lhs.get(), rhs.get(), out.get(), *config)); } + +} // namespace vm } // namespace xla::gpu //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.h b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h index 79dbd3fada167a..cf81ce0c4dc7e1 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/gemm.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h @@ -67,6 +67,8 @@ Status DispatchGemm(const vm::ExecutionContext& ctx, // XLA:GPU gemm custom module API //===-----------------------------------------------------------------------===/ +namespace vm { + class GemmAPI { public: explicit GemmAPI(iree_hal_allocator_t* device_allocator); @@ -89,16 +91,18 @@ class GemmAPI { iree::vm::ref dot_precision); // Dispatches gemm operation with given buffers and config. 
- iree::Status GemmDispatch(iree::vm::ref ctx, + iree::Status GemmDispatch(iree::vm::ref ctx, iree::vm::ref lhs, iree::vm::ref rhs, iree::vm::ref out, - iree::vm::ref config); + iree::vm::ref config, + iree::vm::ref trace); private: iree_hal_allocator_t* device_allocator_; }; +} // namespace vm } // namespace xla::gpu //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/xla/service/gpu/openxla/module.cc b/tensorflow/compiler/xla/service/gpu/openxla/module.cc index 554df91150c7dc..e9e489868818dc 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/module.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/module.cc @@ -32,7 +32,10 @@ namespace xla::gpu { // XLA:GPU custom module state //===-----------------------------------------------------------------------===/ -class XlaGpuModuleState : public GemmAPI { +using vm::GemmAPI; +using vm::TraceAPI; + +class XlaGpuModuleState : public GemmAPI, public TraceAPI { public: explicit XlaGpuModuleState(iree_hal_allocator_t* device_allocator) : GemmAPI(device_allocator) {} @@ -77,6 +80,9 @@ static const iree::vm::NativeFunction kXlaGpuFunctions[] = { MakeApiFunction("dot_precision.create", &GemmAPI::DotPrecisionCreate), MakeApiFunction("dot_config.create", &GemmAPI::DotConfigCreate), MakeApiFunction("gemm.dispatch", &GemmAPI::GemmDispatch), + + // XLA:GPU tracing APIs + MakeApiFunction("trace.create", &TraceAPI::TraceCreate), }; class XlaGpuModule : public iree::vm::NativeModule { @@ -155,6 +161,10 @@ iree_status_t RegisterXlaGpuTypes(iree_vm_instance_t* instance) { IREE_RETURN_IF_ERROR(RegisterType( instance, "xla_gpu.dot_config", &dot_config_registration)); + // XLA:GPU tracing types + IREE_RETURN_IF_ERROR( + RegisterType(instance, "xla_gpu.trace", &trace_registration)); + return iree_ok_status(); } diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.cc b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc index a295b418ea3158..dea05d4d3d68d9 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/vm.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc @@ -15,17 +15,33 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" +#include #include +#include "absl/strings/str_format.h" + namespace xla::gpu::vm { -using iree::StatusOr; +//===----------------------------------------------------------------------===// +// Trace annotations derived from HLO operations +//===----------------------------------------------------------------------===// + +std::string ToScopedAnnotationName(const Trace& trace) { + return absl::StrFormat("Thunk:#hlo_op=%s#", trace.hlo_op); +} + +iree::StatusOr> TraceAPI::TraceCreate( + iree_string_view_t trace) { + auto ref = iree::vm::make_ref(); + ref->hlo_op = std::string(trace.data, trace.size); + return ref; +} //===----------------------------------------------------------------------===// // Helper functions to work with VM lists //===----------------------------------------------------------------------===// -StatusOr> GetI64Vector(const iree_vm_list_t* list) { +iree::StatusOr> GetI64Vector(const iree_vm_list_t* list) { iree_host_size_t size = iree_vm_list_size(list); std::vector values(size); for (iree_host_size_t i = 0; i < size; ++i) { @@ -44,3 +60,4 @@ StatusOr> GetI64Vector(const iree_vm_list_t* list) { //===----------------------------------------------------------------------===// IREE_VM_DEFINE_TYPE_ADAPTERS(execution_context, xla::gpu::vm::ExecutionContext); +IREE_VM_DEFINE_TYPE_ADAPTERS(trace, xla::gpu::vm::Trace); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.h b/tensorflow/compiler/xla/service/gpu/openxla/vm.h index d8cccd347bc01a..0c9d84ea919d84 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/vm.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ +#include #include #include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep @@ -43,6 +44,22 @@ struct ExecutionContext : public iree::vm::RefObject { const DebugOptions* debug_options; }; +//===----------------------------------------------------------------------===// +// Trace annotations derived from HLO operations +//===----------------------------------------------------------------------===// + +struct Trace : public iree::vm::RefObject { + std::string hlo_op; +}; + +std::string ToScopedAnnotationName(const Trace& trace); + +struct TraceAPI { + // Creates `xla_gpu.trace` value. + iree::StatusOr> TraceCreate( + iree_string_view_t trace); +}; + //===----------------------------------------------------------------------===// // Helper functions to work with VM lists //===----------------------------------------------------------------------===// @@ -58,5 +75,6 @@ iree::StatusOr> GetI64Vector(const iree_vm_list_t* list); IREE_VM_DECLARE_TYPE_ADAPTERS(execution_context, xla::gpu::vm::ExecutionContext); +IREE_VM_DECLARE_TYPE_ADAPTERS(trace, xla::gpu::vm::Trace); #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ From 97d1bc28ea2026a919f3915204bfdd04bfb9ea85 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 28 Jul 2023 01:12:30 -0700 Subject: [PATCH 298/410] [XLA] NFC: simplify shape util's `ElementsIn()`, and some IWYU. 
PiperOrigin-RevId: 551772651 --- tensorflow/compiler/xla/pjrt/BUILD | 1 + .../xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc | 1 + tensorflow/compiler/xla/shape_util.h | 11 ++++++----- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index a8ed4d2d3fa348..da94c2e6317b84 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -510,6 +510,7 @@ xla_cc_test( srcs = ["tracked_tfrt_cpu_device_buffer_test.cc"], deps = [ ":tracked_tfrt_cpu_device_buffer", + "//tensorflow/tsl/platform:env", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc index 05d47449c6cd3d..008ca55f95e242 100644 --- a/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/tsl/platform/threadpool.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index a499299b31c751..e8c1df940121db 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -103,12 +104,12 @@ class ShapeUtil { static inline int64_t ElementsIn(const Shape& shape) { DCHECK(shape.IsArray()) << ShapeUtil::HumanString(shape); DCHECK_EQ(shape.dimensions_size(), shape.rank()); - if (shape.dimensions().size() == 1) { - return shape.dimensions()[0]; + if (shape.dimensions().empty()) { + return 1LL; } - return std::accumulate( - shape.dimensions().begin(), shape.dimensions().end(), 1LL, - std::multiplies()); + auto begin = shape.dimensions().begin(); + return std::accumulate(std::next(begin), shape.dimensions().end(), *begin, + std::multiplies()); } // As ElementsIn(), but recurses through tuples. From dcbb4e23fa124f4be119ab9259e4da8fa8d5a322 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 28 Jul 2023 02:02:13 -0700 Subject: [PATCH 299/410] Update GraphDef version to 1571. PiperOrigin-RevId: 551784172 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 6fbd1d7cefb1e0..206f30a344b80d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1570 // Updated: 2023/7/27 +#define TF_GRAPH_DEF_VERSION 1571 // Updated: 2023/7/28 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From e1089d6babe53e9d58c95ef4f224f2a872671044 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 28 Jul 2023 02:02:41 -0700 Subject: [PATCH 300/410] compat: Update forward compatibility horizon to 2023-07-28 PiperOrigin-RevId: 551784303 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index ed761a1484ff06..57a49860479f77 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 27) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 7, 28) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From a868ab27b2dd6eb8fe8809cccebd2050bf13691f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 28 Jul 2023 02:07:38 -0700 Subject: [PATCH 301/410] [xla:gpu] Add StreamExecutor backend Add @xla_gpu.kernel.dispatch API to XLA:GPU module to dispatch device kernels using StreamExecutor API instead of a CUDA HAL. XProf tracing and kernels caching is coming in the next PR, together with a cleanup to remove duplicated XlaApi from the compiler pass. PiperOrigin-RevId: 551785536 --- .../compiler/xla/debug_options_flags.cc | 6 + .../compiler/xla/mlir/backends/openxla/BUILD | 1 + .../mlir/backends/openxla/conversion/BUILD | 1 + .../conversion/convert_compiled_ops.cc | 246 ++++++++++++++++-- .../openxla/conversion/convert_compiled_ops.h | 8 +- .../openxla/conversion/convert_library_ops.cc | 36 ++- .../backends/openxla/ir/xla_gpu_dialect.td | 4 + .../openxla/transforms/convert_to_openxla.cc | 58 ++++- .../backends/openxla/transforms/passes.cc | 5 +- .../mlir/backends/openxla/transforms/passes.h | 28 +- .../backends/openxla/transforms/passes.td | 5 + ...penxla.mlir => fusion_to_openxla_hal.mlir} | 0 .../tests/fusion_to_openxla_se.mlir | 50 ++++ .../mlir/backends/openxla/xla-openxla-opt.cc | 7 + .../service/gpu/compile_module_to_llvm_ir.cc | 6 +- .../compiler/xla/service/gpu/openxla/BUILD | 42 ++- .../xla/service/gpu/openxla/compiler.cc | 6 +- .../xla/service/gpu/openxla/executable.cc | 13 +- .../xla/service/gpu/openxla/executable.h | 4 + .../compiler/xla/service/gpu/openxla/gemm.cc | 12 +- .../compiler/xla/service/gpu/openxla/gemm.h | 14 +- .../compiler/xla/service/gpu/openxla/hal.cc | 9 +- .../compiler/xla/service/gpu/openxla/hal.h | 2 +- .../xla/service/gpu/openxla/kernel.cc | 107 ++++++++ .../compiler/xla/service/gpu/openxla/kernel.h | 88 +++++++ .../xla/service/gpu/openxla/module.cc | 14 +- .../compiler/xla/service/gpu/openxla/vm.cc | 25 +- .../compiler/xla/service/gpu/openxla/vm.h | 25 +- tensorflow/compiler/xla/xla.proto | 9 +- 29 files changed, 725 insertions(+), 106 deletions(-) rename tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/{fusion_to_openxla.mlir => fusion_to_openxla_hal.mlir} (100%) create mode 100644 tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_se.mlir create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/kernel.cc create mode 100644 tensorflow/compiler/xla/service/gpu/openxla/kernel.h diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 576fbbbfbf2e8f..9684f416719d94 100644 --- 
a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -136,6 +136,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // OpenXLA/IREE runtime flags. opts.set_xla_gpu_enable_openxla_runtime(false); + opts.set_xla_gpu_enable_openxla_hal(true); // Set 4GB space limit for redzone scratch allocator. opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); @@ -969,6 +970,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_openxla_runtime), debug_options->xla_gpu_enable_openxla_runtime(), "Whether to enable OpenXLA runtime for XLA:GPU backend")); + flag_list->push_back( + tsl::Flag("xla_gpu_enable_openxla_hal", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_openxla_hal), + debug_options->xla_gpu_enable_openxla_hal(), + "Whether to enable OpenXLA CUDA HAL for XLA:GPU backend")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", int64_setter_for( diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD index 886086e8efcb73..c381bbc7b9b27d 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD @@ -34,6 +34,7 @@ config_setting( # "//tensorflow/compiler/xla/mlir/backends/openxla/transforms:passes", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", # "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", +# "//tensorflow/tsl/platform:platform_port", # ], # ) # diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD index 5a858ae8c8d232..b2a52ea3704b25 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/BUILD @@ -37,6 +37,7 @@ package( # "@llvm-project//mlir:Support", # "@llvm-project//mlir:TensorDialect", # "@llvm-project//mlir:Transforms", +# "//tensorflow/compiler/xla/mlir/backends/openxla/ir:xla_gpu", # "//tensorflow/compiler/xla/mlir_hlo:lhlo", # "//tensorflow/compiler/xla/service/gpu:gpu_executable", # "//tensorflow/compiler/xla/service/gpu:launch_dimensions", diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc index 924da08d5ca78c..530ce7faa74b77 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -47,6 +48,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/de_bufferization.h" +#include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" @@ -60,6 +62,99 @@ namespace { using namespace mlir; // NOLINT using namespace mlir::iree_compiler; // NOLINT +using arith::ConstantIndexOp; +using arith::ConstantIntOp; + +//===----------------------------------------------------------------------===// +// Helper functions to build arguments to API functions. +//===----------------------------------------------------------------------===// + +Type getBufferViewListType(OpBuilder &b) { + return b.getType( + b.getType()); +} + +// Creates `iree_input.list` list. +TypedValue getBufferViewList( + ImplicitLocOpBuilder &b, ArrayRef> values) { + Type type = getBufferViewListType(b); + Value size = b.create(values.size()); + Value list = b.create(type, size); + + if (!values.empty()) b.create(list, size); + for (auto indexed : llvm::enumerate(values)) { + Value index = b.create(indexed.index()); + Value view = b.create( + b.getType(), indexed.value(), + /*source_dims=*/ValueRange()); + b.create(list, index, view); + } + + return list.cast>(); +} + +//===----------------------------------------------------------------------===// +// Helper class to set up OpenXLA runtime API declarations +//===----------------------------------------------------------------------===// + +// XLA GPU <-> StreamExecutor integration as API declarations. +class XlaGpuApi { + public: + // Imports `@xla_gpu.kernel.create` into the module. + func::FuncOp getCreateKernel(OpBuilder &b, ModuleOp module); + + // Imports `@xla_gpu.kernel.dispatch` into the module. 
+ func::FuncOp getDispatchKernel(OpBuilder &b, ModuleOp module); + + private: + SymbolTable &symTable(ModuleOp module); + + func::FuncOp addDecl(OpBuilder &b, ModuleOp module, std::string_view name, + FunctionType function_type); + + SymbolTableCollection sym_table_; +}; + +func::FuncOp XlaGpuApi::getCreateKernel(OpBuilder &b, ModuleOp module) { + SmallVector args = { + b.getType(), // kernel_name + b.getI32Type(), // shared_memory_bytes + }; + SmallVector rets = {b.getType()}; + return addDecl(b, module, "xla_gpu.kernel.create", + FunctionType::get(b.getContext(), args, rets)); +} + +func::FuncOp XlaGpuApi::getDispatchKernel(OpBuilder &b, ModuleOp module) { + SmallVector args = {b.getType(), + b.getType(), getBufferViewListType(b)}; + args.append(6, b.getI32Type()); // workgroup_size / workload_size + return addDecl(b, module, "xla_gpu.kernel.dispatch", + FunctionType::get(b.getContext(), args, /*rets=*/{})); +} + +SymbolTable &XlaGpuApi::symTable(ModuleOp module) { + return sym_table_.getSymbolTable(module); +} + +func::FuncOp XlaGpuApi::addDecl(OpBuilder &b, ModuleOp module, + std::string_view name, + FunctionType function_type) { + if (auto fn = sym_table_.lookupNearestSymbolFrom( + module, b.getStringAttr(name))) + return fn; + + Location loc = UnknownLoc::get(module->getContext()); + + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(module.getBody()); + + auto fn = b.create(loc, name, function_type); + fn.setPrivate(); + symTable(module).insert(fn); + return fn; +} + //===----------------------------------------------------------------------===// // Helper functions to work with ThunkSequence //===----------------------------------------------------------------------===// @@ -265,14 +360,14 @@ IREE::Input::PipelineLayoutAttr getPipelineLayout(CompiledOp &op) { //===----------------------------------------------------------------------===// template -struct ConvertCompiledOp : public OpConversionPattern { +struct ConvertCompiledOpToHal : public OpConversionPattern { using OpAdaptor = typename OpConversionPattern::OpAdaptor; - ConvertCompiledOp(TypeConverter &converter, MLIRContext *ctx, - IREE::Input::ExecutableSourceOp executable_source, - ThunkSequence *thunk_sequence, - std::shared_ptr state, - std::shared_ptr ordinal) + ConvertCompiledOpToHal(TypeConverter &converter, MLIRContext *ctx, + IREE::Input::ExecutableSourceOp executable_source, + ThunkSequence *thunk_sequence, + std::shared_ptr state, + std::shared_ptr ordinal) : OpConversionPattern(converter, ctx), executable_source(executable_source.getSymNameAttr()), executable_source_body(&executable_source.getBody().front()), @@ -295,18 +390,20 @@ struct ConvertCompiledOp : public OpConversionPattern { }; template -LogicalResult ConvertCompiledOp::matchAndRewrite( +LogicalResult ConvertCompiledOpToHal::matchAndRewrite( OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto *block = op->getBlock(); - // Extract compiled fusion from the thunk sequence. - auto compiled_fusion = extractCompiledOp(op, thunk_sequence, rewriter); - if (failed(compiled_fusion)) return failure(); + // Extract compiled operation from the thunk sequence. + auto compiled_op = extractCompiledOp(op, thunk_sequence, rewriter); + if (failed(compiled_op)) + return rewriter.notifyMatchFailure( + op, "failed to extract device compilation result for an operation"); // Handle copy operations first, before handling kernel launch. 
- for (auto © : compiled_fusion->memcpy) { + for (auto © : compiled_op->memcpy) { auto src_memref = cast>(copy->source_value()); auto dst_memref = cast>(copy->destination_value()); @@ -344,13 +441,13 @@ LogicalResult ConvertCompiledOp::matchAndRewrite( } // Compiled fusion was a plain copy operation. - if (thunk_sequence != nullptr && !compiled_fusion->kernel) { + if (thunk_sequence != nullptr && !compiled_op->kernel) { rewriter.eraseOp(op); return success(); } // Get kernel launch parameters from a compiled fusion. - auto [kernel_name, dims] = getKernelLaunchParams(*compiled_fusion); + auto [kernel_name, dims] = getKernelLaunchParams(*compiled_op); SmallVector workgroup_size = { dims.thread_counts_per_block().x, @@ -377,7 +474,7 @@ LogicalResult ConvertCompiledOp::matchAndRewrite( executable_export = b.create( /*sym_name=*/b.getStringAttr(kernel_name), /*ordinal=*/b.getIndexAttr((*ordinal)++), - /*layout=*/getPipelineLayout(*compiled_fusion), + /*layout=*/getPipelineLayout(*compiled_op), /*workgroup_size=*/b.getIndexArrayAttr(workgroup_size), /*subgroup_size=*/nullptr, /*workgroup_local_memory=*/shmem ? b.getIndexAttr(shmem) : nullptr); @@ -392,12 +489,12 @@ LogicalResult ConvertCompiledOp::matchAndRewrite( return b.create(size); })); - auto dispatch_args = getDispatchArguments(*compiled_fusion, *state); + auto dispatch_args = getDispatchArguments(*compiled_op, *state); auto &memrefs = dispatch_args.first; auto &tensors = dispatch_args.second; // Prepare tied operands and corresponding result types. - SmallVector tied_operands = getTiedOperands(*compiled_fusion); + SmallVector tied_operands = getTiedOperands(*compiled_op); SmallVector results = llvm::to_vector(llvm::map_range(tied_operands, [&](int64_t idx) -> Type { return tensors[idx].getType(); @@ -422,10 +519,107 @@ LogicalResult ConvertCompiledOp::matchAndRewrite( } //===----------------------------------------------------------------------===// -// Converts lmhlo.fusion op to a iree_input.dispatch +// Converts compiled op to an XLA:GPU kernel dispatch API call +//===----------------------------------------------------------------------===// + +TypedValue getExecutionContext(Operation *op) { + auto func = op->getParentOfType(); + return func.getArguments().front().cast>(); +} + +template +struct ConvertCompiledOpToApiCall : public OpConversionPattern { + using OpAdaptor = typename OpConversionPattern::OpAdaptor; + + ConvertCompiledOpToApiCall(TypeConverter &converter, MLIRContext *ctx, + ThunkSequence *thunk_sequence, + std::shared_ptr state, + std::shared_ptr api) + : OpConversionPattern(converter, ctx), + thunk_sequence(thunk_sequence), + state(std::move(state)), + api(std::move(api)) {} + + LogicalResult matchAndRewrite( + OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + ThunkSequence *thunk_sequence; + std::shared_ptr state; + std::shared_ptr api; +}; + +template +LogicalResult ConvertCompiledOpToApiCall::matchAndRewrite( + OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto module = op->template getParentOfType(); + + // Extract compiled operation from the thunk sequence. + auto compiled_op = extractCompiledOp(op, thunk_sequence, rewriter); + if (failed(compiled_op)) + return rewriter.notifyMatchFailure( + op, "failed to extract device compilation result for an operation"); + + // TODO(ezhulenev): Add support for memcpy thunks. 
+ if (!compiled_op->memcpy.empty()) + return rewriter.notifyMatchFailure(op, "memcpy thunks are not supported"); + + // Get kernel launch parameters from a compiled fusion. + auto [kernel_name, dims] = getKernelLaunchParams(*compiled_op); + + // Create XLA:GPU device kernel (it will own loaded PTX/CUBIN at run time). + Value name = b.create( + b.getType(), + /*name=*/b.getStringAttr("kernel_name"), /*value=*/kernel_name, + /*alignment=*/nullptr, /*mime_type=*/nullptr); + Value shmem = b.create(dims.SharedMemBytes(), 32); + + func::FuncOp create_kernel = api->getCreateKernel(b, module); + Value kernel = b.create(create_kernel.getSymName(), + create_kernel.getResultTypes(), + ValueRange({name, shmem})) + .getResult(0); + + // Prepare arguments for kernel dispatch. + SmallVector workgroup_size = { + b.create(dims.thread_counts_per_block().x, 32), + b.create(dims.thread_counts_per_block().y, 32), + b.create(dims.thread_counts_per_block().z, 32), + }; + + SmallVector workload_size = { + b.create(dims.block_counts().x, 32), + b.create(dims.block_counts().y, 32), + b.create(dims.block_counts().z, 32), + }; + + auto dispatch_args = getDispatchArguments(*compiled_op, *state); + auto &tensors = dispatch_args.second; + + Value buffer_views = getBufferViewList(b, tensors); + + // Prepare arguments for the kernel dispatch API call. + SmallVector args = {getExecutionContext(op), kernel, buffer_views}; + args.append(workgroup_size.begin(), workgroup_size.end()); + args.append(workload_size.begin(), workload_size.end()); + + func::FuncOp dispatch_kernel = api->getDispatchKernel(b, module); + // TODO(ezhulenev): Should we import buffer view back and update remapping? + b.create(dispatch_kernel.getSymName(), + dispatch_kernel.getResultTypes(), args); + + rewriter.eraseOp(op); + return success(); +} + +//===----------------------------------------------------------------------===// +// Converts lmhlo.fusion op to HAL / XLA:GPU runtime //===----------------------------------------------------------------------===// -using ConvertFusionOp = ConvertCompiledOp; +using ConvertFusionOpToHal = ConvertCompiledOpToHal; +using ConvertFusionOpToApiCall = ConvertCompiledOpToApiCall; // Returns Fusion kernel pipeline layout (ABI) inferred from the fusion // operation body looking at tensor<->memref conversions. 
@@ -481,10 +675,10 @@ SmallVector getTiedOperands(lmhlo::FusionOp op) { } //===----------------------------------------------------------------------===// -// Converts lmhlo.sort op to a iree_input.dispatch +// Converts lmhlo.sort op to to HAL / XLA:GPU runtime //===----------------------------------------------------------------------===// -using ConvertSortOp = ConvertCompiledOp; +using ConvertSortOpToHal = ConvertCompiledOpToHal; IREE::Input::PipelineLayoutAttr getPipelineLayout(lmhlo::SortOp op) { auto n_args = op.getInputs().size(); @@ -580,10 +774,20 @@ void populateCompiledOpsConversionPatterns( IREE::Input::ExecutableSourceOp executable_source, ThunkSequence *thunk_sequence, std::shared_ptr state) { auto *ctx = patterns.getContext(); - patterns.insert( + patterns.insert( converter, ctx, executable_source, thunk_sequence, state, /*ordinal=*/std::make_shared(0)); patterns.insert(converter, ctx, state); } +void populateCompiledOpsConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, + ThunkSequence *thunk_sequence, std::shared_ptr state) { + auto api = std::make_shared(); + auto *ctx = patterns.getContext(); + patterns.insert(converter, ctx, thunk_sequence, + state, api); + patterns.insert(converter, ctx, state); +} + } // namespace xla::gpu diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h index b04e5511a3526a..339e58c9c80681 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h @@ -29,13 +29,19 @@ namespace gpu { // Forward declare. class ThunkSequence; -// Appends patterns to convert LMHLO operations compiled to kernel thunks to an +// Appends patterns to convert LMHLO operations compiled to kernel thunks to // IREEInput executable export and dispatch operations. void populateCompiledOpsConversionPatterns( mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, mlir::iree_compiler::IREE::Input::ExecutableSourceOp executable_source, ThunkSequence *thunk_sequence, std::shared_ptr state); +// Appends patterns to convert LMHLO operations compiled to kernel thunks to +// XLA:GPU runtime API calls. 
+void populateCompiledOpsConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter, + ThunkSequence *thunk_sequence, std::shared_ptr state); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc index 1f17b3c1ce71e2..d88be76b6b8c09 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.cc @@ -80,24 +80,24 @@ class XlaGpuApi { SymbolTableCollection sym_table_; }; -Type getI64ListType(MLIRContext *ctx) { - return IREE::Input::ListType::get(ctx, IntegerType::get(ctx, 64)); +Type getI32ListType(OpBuilder &b) { + return b.getType(b.getI32Type()); } func::FuncOp XlaGpuApi::getCreateDotDimensionsNumbers(OpBuilder &b, ModuleOp module) { - auto i64_list = getI64ListType(b.getContext()); - SmallVector args = {/*lhs_batch_dimensions=*/i64_list, - /*rhs_batch_dimensions=*/i64_list, - /*lhs_contracting_dimensions=*/i64_list, - /*rhs_contracting_dimensions=*/i64_list}; + auto i32_list = getI32ListType(b); + SmallVector args = {/*lhs_batch_dimensions=*/i32_list, + /*rhs_batch_dimensions=*/i32_list, + /*lhs_contracting_dimensions=*/i32_list, + /*rhs_contracting_dimensions=*/i32_list}; SmallVector rets = {b.getType()}; return addDecl(b, module, "xla_gpu.dot_dimension_numbers.create", FunctionType::get(b.getContext(), args, rets)); } func::FuncOp XlaGpuApi::getCreateDotPrecision(OpBuilder &b, ModuleOp module) { - SmallVector args = {getI64ListType(b.getContext())}; + SmallVector args = {getI32ListType(b)}; SmallVector rets = {b.getType()}; return addDecl(b, module, "xla_gpu.dot_precision.create", FunctionType::get(b.getContext(), args, rets)); @@ -161,18 +161,16 @@ func::FuncOp XlaGpuApi::addDecl(OpBuilder &b, ModuleOp module, // Helper functions to build arguments to API functions. //===----------------------------------------------------------------------===// -// Creates `iree_input.list` list from the given values. -TypedValue getI64List(ImplicitLocOpBuilder &b, +// Creates `iree_input.list` list from the given values. 
+TypedValue getI32List(ImplicitLocOpBuilder &b, ArrayRef values) { - MLIRContext *ctx = b.getContext(); - Value size = b.create(values.size()); - Value list = b.create(getI64ListType(ctx), size); + Value list = b.create(getI32ListType(b), size); if (!values.empty()) b.create(list, size); for (auto indexed : llvm::enumerate(values)) { Value index = b.create(indexed.index()); - Value value = b.create(indexed.value(), 64); + Value value = b.create(indexed.value(), 32); b.create(list, index, value); } @@ -187,10 +185,10 @@ TypedValue getDotDimensionNumbers( XlaGpuApi &api, ImplicitLocOpBuilder &b, ModuleOp module, lmhlo_gpu::GEMMOp op) { mhlo::DotDimensionNumbersAttr attr = op.getDotDimensionNumbersAttr(); - SmallVector args = {getI64List(b, attr.getLhsBatchingDimensions()), - getI64List(b, attr.getRhsBatchingDimensions()), - getI64List(b, attr.getLhsContractingDimensions()), - getI64List(b, attr.getRhsContractingDimensions())}; + SmallVector args = {getI32List(b, attr.getLhsBatchingDimensions()), + getI32List(b, attr.getRhsBatchingDimensions()), + getI32List(b, attr.getLhsContractingDimensions()), + getI32List(b, attr.getRhsContractingDimensions())}; auto api_func = api.getCreateDotDimensionsNumbers(b, module); auto call = b.create(api_func.getSymName(), @@ -209,7 +207,7 @@ TypedValue getDotPrecision(XlaGpuApi &api, return static_cast(value); })); - SmallVector args = {getI64List(b, precision)}; + SmallVector args = {getI32List(b, precision)}; auto api_func = api.getCreateDotPrecision(b, module); auto call = b.create(api_func.getSymName(), diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td index bb75d411877183..a60a4874aa41d6 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td +++ b/tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.td @@ -65,6 +65,10 @@ def DotPrecisionType : XLA_GPU_Type<"DotPrecision", "dot_precision"> { let summary = "Precision for dot operation"; } +def KernelType : XLA_GPU_Type<"Kernel", "kernel"> { + let summary = "XLA:GPU device kernel"; +} + def TraceType : XLA_GPU_Type<"Trace", "trace"> { let summary = "XLA:GPU trace annotation"; } diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc index ed343ccf6ac690..6a1cebaefca326 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/convert_to_openxla.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include #include #include "third_party/iree/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputDialect.h" @@ -30,12 +33,14 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_compiled_ops.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_library_ops.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_memref_ops.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/conversion/convert_while_op.h" #include "tensorflow/compiler/xla/mlir/backends/openxla/ir/xla_gpu_dialect.h" +#include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #define GEN_PASS_DECL_CONVERTTOOPENXLA @@ -76,6 +81,24 @@ IREE::Input::ExecutableSourceOp createXlaExecutableSource(ModuleOp module) { //===----------------------------------------------------------------------===// +static std::string toString(OpenXlaBackend backend) { + switch (backend) { + case OpenXlaBackend::kHAL: + return "hal"; + case OpenXlaBackend::kStreamExecutor: + return "streamexecutor"; + } +} + +static FailureOr parseOpenXlaBackend(std::string_view str) { + if (str == "hal") { + return OpenXlaBackend::kHAL; + } else if (str == "streamexecutor") { + return OpenXlaBackend::kStreamExecutor; + } + return failure(); +} + // Adds `xla_gpu.execution_context` argument to all functions in the module. static void addExecutionContextArgument(ModuleOp module) { MLIRContext *ctx = module.getContext(); @@ -91,18 +114,27 @@ static void addExecutionContextArgument(ModuleOp module) { class ConvertToOpenXlaPass : public ::impl::ConvertToOpenXlaBase { public: - explicit ConvertToOpenXlaPass(ThunkSequence *thunk_sequence) - : thunk_sequence_(thunk_sequence) {} + ConvertToOpenXlaPass(ThunkSequence *thunk_sequence, + std::optional backend) + : thunk_sequence_(thunk_sequence) { + if (backend.has_value()) { + this->backend = toString(*backend); + } + } void runOnOperation() override { auto *ctx = &getContext(); + // Lower compiled operations to HAL or SE runtime. + auto compiled_ops_backend = parseOpenXlaBackend(backend); + if (failed(compiled_ops_backend)) { + getOperation().emitError() << "unsupported backend: " << backend; + return signalPassFailure(); + } + // Add execution context argument to all functions in the module. addExecutionContextArgument(getOperation()); - // Add all pre-compiled XLA fusions to the module as an executable source. 
- auto executable_source = createXlaExecutableSource(getOperation()); - TypeConverter converter; converter.addConversion([](Type type) { return type; }); @@ -129,8 +161,16 @@ class ConvertToOpenXlaPass populateLibraryOpsConversionPatterns(patterns, converter, state); populateMemrefConversionPatterns(patterns, converter, state); populateWhileOpConversionPatterns(patterns, converter, state); - populateCompiledOpsConversionPatterns( - patterns, converter, executable_source, thunk_sequence_, state); + + if (*compiled_ops_backend == OpenXlaBackend::kHAL) { + auto executable_source = createXlaExecutableSource(getOperation()); + populateCompiledOpsConversionPatterns( + patterns, converter, executable_source, thunk_sequence_, state); + + } else if (*compiled_ops_backend == OpenXlaBackend::kStreamExecutor) { + populateCompiledOpsConversionPatterns(patterns, converter, + thunk_sequence_, state); + } // Ensure all HLO and memref operations get lowered to IREEInput and OpenXLA // runtime. For this we have to de-bufferize the IR and correctly tie @@ -157,8 +197,8 @@ class ConvertToOpenXlaPass }; std::unique_ptr> createConvertToOpenXlaPass( - ThunkSequence *thunk_sequence) { - return std::make_unique(thunk_sequence); + ThunkSequence *thunk_sequence, std::optional backend) { + return std::make_unique(thunk_sequence, backend); } } // namespace xla::gpu diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.cc b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.cc index 935dc3d233c385..47e73ff20e1cb2 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.cc @@ -32,8 +32,9 @@ using namespace mlir; // NOLINT void registerOpenXlaPases() { ::impl::registerPasses(); } void populateOpenXlaRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence) { - pm.addPass(createConvertToOpenXlaPass(thunk_sequence)); + ThunkSequence* thunk_sequence, + OpenXlaBackend backend) { + pm.addPass(createConvertToOpenXlaPass(thunk_sequence, backend)); pm.addPass(createCanonicalizerPass()); } diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h index 723cccd9eb86bf..219e61345c79c0 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h @@ -16,6 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_TRANSFORMS_PASSES_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_OPENXLA_TRANSFORMS_PASSES_H_ +namespace xla::gpu { + +class ThunkSequence; // forward declare + +// We have two options for lowering executing compiled device kernels: +// (1) Use IREEs HAL, export all device kernels as executable source, and +// dispatch them using `iree_input.dispatch` (later lowered to Flow) +// (2) Use XLA:GPU StreamExecutor APIs to load and dispatch device kernels +enum class OpenXlaBackend { kHAL, kStreamExecutor }; + +} // namespace xla::gpu + //===----------------------------------------------------------------------===// // TODO(ezhulenev): We currently do not build with OpenXLA runtime in open // source because we do not have bazel dependency from XLA to IREE. 
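For illustration, a minimal sketch of how a caller could pick between the two lowering options described in the comment above, using only the `OpenXlaBackend` enum and `populateOpenXlaRuntimePasses` declared in this patch; the `BuildOpenXlaPipeline` helper and its `use_hal` flag are hypothetical names introduced here, not part of the patch:

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h"

// Hypothetical helper: option (1) lowers compiled ops to IREE HAL dispatches,
// option (2) keeps kernel launches on the XLA:GPU StreamExecutor path.
static void BuildOpenXlaPipeline(mlir::OpPassManager& pm,
                                 xla::gpu::ThunkSequence* thunks,
                                 bool use_hal) {
  xla::gpu::OpenXlaBackend backend =
      use_hal ? xla::gpu::OpenXlaBackend::kHAL
              : xla::gpu::OpenXlaBackend::kStreamExecutor;
  xla::gpu::populateOpenXlaRuntimePasses(pm, thunks, backend);
}
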
@@ -27,10 +39,8 @@ class OpPassManager; } // namespace mlir namespace xla::gpu { -class ThunkSequence; -inline void populateOpenXlaRuntimePasses(mlir::OpPassManager&, ThunkSequence*) { -} - +inline void populateOpenXlaRuntimePasses(mlir::OpPassManager&, ThunkSequence*, + OpenXlaBackend backend) {} inline void registerOpenXlaPases() {} } // namespace xla::gpu @@ -39,26 +49,28 @@ inline void registerOpenXlaPases() {} //===----------------------------------------------------------------------===// #include +#include #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project namespace xla::gpu { -class ThunkSequence; // forward declare - // Populate passes that lower MLIR modules from a combination of LMHLO and // LMHLO_GPU dialects to the OpenXLA runtime (aka IREE input dialects + OpenXLA // custom calls implementing library integration). void populateOpenXlaRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence); + ThunkSequence* thunk_sequence, + OpenXlaBackend backend); //===----------------------------------------------------------------------===// // Conversion from LMHLO dialects to OpenXLA runtime //===----------------------------------------------------------------------===// std::unique_ptr > -createConvertToOpenXlaPass(ThunkSequence* thunk_sequence = nullptr); +createConvertToOpenXlaPass( + ThunkSequence* thunk_sequence = nullptr, + std::optional backend = std::nullopt); //===----------------------------------------------------------------------===// // OpenXLA passes registration diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td index b7b33e34e29b3e..68af7daaef7da9 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.td @@ -32,6 +32,11 @@ def ConvertToOpenXla : Pass<"xla-gpu-to-openxla", "mlir::ModuleOp"> { "mlir::scf::SCFDialect", "xla::gpu::XlaGpuDialect", ]; + + let options = [ + Option<"backend", "backend", "std::string", "\"hal\"", + "The backend for launching kernels: hal or streamexecutor">, + ]; } #endif // XLA_GPU_OPENXLA_PASSES diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_hal.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla.mlir rename to tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_hal.mlir diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_se.mlir b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_se.mlir new file mode 100644 index 00000000000000..2e0322940167eb --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/openxla/transforms/tests/fusion_to_openxla_se.mlir @@ -0,0 +1,50 @@ +// RUN: export MSAN_OPTIONS=intercept_strpbrk=0 +// RUN: xla-openxla-opt %s --xla-gpu-to-openxla=backend=streamexecutor \ +// RUN: --split-input-file \ +// RUN: | FileCheck %s + +func.func @fusion( + %arg0: memref<12xi8>, %arg1: memref<12xi8>, + %arg2: memref<12xi8> {lmhlo.output_index = dense<> : tensor<0xi64>} +) { + %c0 = arith.constant 0 : index + %view0 = memref.view %arg0[%c0][] : memref<12xi8> to memref<3xf32> + %view1 = memref.view %arg1[%c0][] : memref<12xi8> to memref<3xf32> + %view2 = memref.view %arg2[%c0][] : 
memref<12xi8> to memref<3xf32> + "lmhlo.fusion"() ({ + %0 = bufferization.to_tensor %view0 : memref<3xf32> + %1 = bufferization.to_tensor %view1 : memref<3xf32> + %2 = mhlo.add %0, %1 : tensor<3xf32> + memref.tensor_store %2, %view2 : memref<3xf32> + "lmhlo.terminator"() : () -> () + }) : () -> () + "lmhlo.terminator"() : () -> () +} + +// CHECK-LABEL: func @fusion( +// CHECK: %[[CTX:.*]]: !xla_gpu.execution_context, +// CHECK: %[[ARG0:.*]]: tensor<12xi8>, %[[ARG1:.*]]: tensor<12xi8>, +// CHECK: %[[ARG2:.*]]: tensor<12xi8> {lmhlo.output_index = {{.*}}} +// CHECK: ) { +// CHECK: %[[TENSOR0:.*]] = iree_input.tensor.import {{.*}} -> tensor<3xf32> +// CHECK: %[[TENSOR1:.*]] = iree_input.tensor.import {{.*}} -> tensor<3xf32> +// CHECK: %[[TENSOR2:.*]] = iree_input.tensor.import {{.*}} -> tensor<3xf32> +// +// CHECK: %[[KERNEL:.*]] = call @xla_gpu.kernel.create +// +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[ARGS:.*]] = iree_input.list.create %[[C3]] +// CHECK-SAME: !iree_input.list +// +// CHECK: %[[BUF0:.*]] = iree_input.tensor.export %[[TENSOR0]] +// CHECK: iree_input.list.set {{.*}}, %[[BUF0]] +// CHECK: %[[BUF1:.*]] = iree_input.tensor.export %[[TENSOR1]] +// CHECK: iree_input.list.set {{.*}}, %[[BUF1]] +// CHECK: %[[BUF2:.*]] = iree_input.tensor.export %[[TENSOR2]] +// CHECK: iree_input.list.set {{.*}}, %[[BUF2]] +// +// CHECK: call @xla_gpu.kernel.dispatch(%[[CTX]], %[[KERNEL]], %[[ARGS]] +// CHECK-SAME: (!xla_gpu.execution_context, !xla_gpu.kernel, +// CHECK-SAME: !iree_input.list, i32, i32, i32, +// CHECK-SAME: i32, i32, i32) -> () +// CHECK: } diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc b/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc index 843cd7a0b141dd..495714dd5c5e6a 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc +++ b/tensorflow/compiler/xla/mlir/backends/openxla/xla-openxla-opt.cc @@ -20,10 +20,17 @@ limitations under the License. #include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/tsl/platform/init_main.h" using namespace mlir; // NOLINT int main(int argc, char **argv) { + // Initialize the process. On OSS this is a no-op. + // Note: we do not parse any flags here; all flags are parsed by + // `MlirOptMain` below. + int32_t argc1 = 1; + tsl::port::InitMain("Xla OpenXLA Pass Driver", &argc1, &argv); + DialectRegistry registry; registry.insertgetName(), mlir::PassManager::Nesting::Implicit); - populateOpenXlaRuntimePasses(pm, thunk_sequence); + + OpenXlaBackend backend = debug_options.xla_gpu_enable_openxla_hal() + ? 
OpenXlaBackend::kHAL + : OpenXlaBackend::kStreamExecutor; + populateOpenXlaRuntimePasses(pm, thunk_sequence, backend); if (pm.run(module).failed()) { return InternalError("Failed to lower LMHLO to OpenXLA input dialects."); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index aa7ef22c441f6a..30273d26a86460 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -56,6 +56,9 @@ config_setting( # compatible_with = [], # defines = if_not_openxla(["XLA_DISABLE_OPENXLA_RUNTIME=1"]), # deps = [ +# ":compiler", +# ":module", +# ":vm", # "@com_google_absl//absl/log", # "@com_google_absl//absl/log:check", # "@com_google_absl//absl/strings", @@ -69,9 +72,6 @@ config_setting( # "//tensorflow/compiler/xla/service/gpu:buffer_allocations", # "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # ] + if_openxla([ -# ":compiler", -# ":module", -# ":vm", # "//third_party/iree/runtime/src/iree/base", # "//third_party/iree/runtime/src/iree/hal", # "//third_party/iree/runtime/src/iree/hal/drivers/cuda", @@ -88,6 +88,9 @@ config_setting( # hdrs = if_openxla(["gemm.h"]), # compatible_with = [], # deps = [ +# ":hal", +# ":vm", +# "@com_google_absl//absl/container:inlined_vector", # "@com_google_absl//absl/log", # "//tensorflow/compiler/xla:status", # "//tensorflow/compiler/xla:statusor", @@ -96,8 +99,29 @@ config_setting( # "//tensorflow/compiler/xla/service/gpu:matmul_utils", # "//tensorflow/tsl/profiler/lib:scoped_annotation", # ] + if_openxla([ +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/vm", +# ]), +# ) +# +# cc_library( +# name = "kernel", +# srcs = if_openxla(["kernel.cc"]), +# hdrs = if_openxla(["kernel.h"]), +# compatible_with = [], +# deps = [ # ":hal", # ":vm", +# "@com_google_absl//absl/container:inlined_vector", +# "@com_google_absl//absl/log", +# "//tensorflow/compiler/xla:status", +# "//tensorflow/compiler/xla:statusor", +# "//tensorflow/compiler/xla/service:executable", +# "//tensorflow/compiler/xla/service/gpu:launch_dimensions", +# "//tensorflow/compiler/xla/service/gpu:stream_executor_util", +# "//tensorflow/compiler/xla/stream_executor:device_memory", +# "//tensorflow/tsl/profiler/lib:scoped_annotation", +# ] + if_openxla([ # "//third_party/iree/runtime/src/iree/hal", # "//third_party/iree/runtime/src/iree/vm", # ]), @@ -109,11 +133,12 @@ config_setting( # hdrs = if_openxla(["module.h"]), # compatible_with = [], # deps = [ -# "@com_google_absl//absl/log", -# ] + if_openxla([ # ":gemm", # ":hal", +# ":kernel", # ":vm", +# "@com_google_absl//absl/log", +# ] + if_openxla([ # "//third_party/iree/runtime/src/iree/base", # "//third_party/iree/runtime/src/iree/hal", # "//third_party/iree/runtime/src/iree/modules/hal:types", @@ -142,9 +167,14 @@ config_setting( # srcs = if_openxla(["vm.cc"]), # hdrs = if_openxla(["vm.h"]), # compatible_with = [], -# deps = ["@com_google_absl//absl/strings:str_format"] + if_openxla([ +# deps = [ +# "@com_google_absl//absl/container:inlined_vector", +# "@com_google_absl//absl/strings:str_format", +# ] + if_openxla([ # "//third_party/iree/runtime/src/iree/vm", # "//third_party/iree/runtime/src/iree/vm:cc", +# "//third_party/iree/runtime/src/iree/hal", +# "//third_party/iree/runtime/src/iree/modules/hal:types", # ]), # ) # diff --git a/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc index 1077534141b66f..d107c4d78bef27 100644 --- 
a/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/compiler.cc @@ -38,7 +38,6 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/platform.h" #if defined(PLATFORM_GOOGLE) @@ -225,7 +224,10 @@ Status BindXlaDeviceKernels(mlir::ModuleOp module, std::string_view asm_text, auto src = sym_table.lookup("xla.module.ptx"); - if (!src) return InternalError("failed to find XLA executable source"); + + // If we are running with StreamExecutor backend we might not have executable + // source for kernels. + if (!src) return OkStatus(); // Bind XLA device kernels to an executable source. auto objects = getExecutableObjects(getExecutableTarget(ctx), diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.cc b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc index d242bc969e12e5..06f4b96f277bf4 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/executable.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.cc @@ -168,13 +168,14 @@ OpenXlaRuntimeExecutable::Create(std::unique_ptr program, return std::unique_ptr(new OpenXlaRuntimeExecutable( std::move(device), std::move(bytecode), std::move(program->buffer_sizes), - std::move(program->debug_options), context, instance, std::move(modules), - function)); + std::move(program->debug_options), asm_text, binary, context, instance, + std::move(modules), function)); } OpenXlaRuntimeExecutable::OpenXlaRuntimeExecutable( std::unique_ptr device, std::unique_ptr bytecode, std::vector buffer_sizes, DebugOptions debug_options, + std::string_view asm_text, absl::Span binary, iree::vm::ref context, iree::vm::ref instance, std::unique_ptr> modules, @@ -183,6 +184,8 @@ OpenXlaRuntimeExecutable::OpenXlaRuntimeExecutable( bytecode_(std::move(bytecode)), buffer_sizes_(std::move(buffer_sizes)), debug_options_(std::move(debug_options)), + asm_text_(asm_text), + binary_(binary), context_(std::move(context)), instance_(std::move(instance)), modules_(std::move(modules)), @@ -212,8 +215,10 @@ Status OpenXlaRuntimeExecutable::Execute( &inputs)); // Add execution context as the first arguments. - auto execution_context = - iree::vm::make_ref(run_options, &debug_options_); + auto execution_context = iree::vm::make_ref( + run_options, &debug_options_, + vm::ExecutionContext::ExecutableSource{asm_text_, binary_}); + // TODO(ezhulenev): Can we do ref_move here? IREE_CHECK_OK(iree_vm_list_push_ref_retain(inputs.get(), execution_context)); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/executable.h b/tensorflow/compiler/xla/service/gpu/openxla/executable.h index d34d9760a42f1a..7a7f247e02ca93 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/executable.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/executable.h @@ -134,6 +134,7 @@ class OpenXlaRuntimeExecutable { OpenXlaRuntimeExecutable( std::unique_ptr device, std::unique_ptr bytecode, std::vector buffer_sizes, DebugOptions debug_options, + std::string_view asm_text, absl::Span binary, iree::vm::ref context, iree::vm::ref instance, std::unique_ptr> modules, @@ -150,6 +151,9 @@ class OpenXlaRuntimeExecutable { std::vector buffer_sizes_; const DebugOptions debug_options_; + std::string_view asm_text_; + absl::Span binary_; + // TODO(ezhulenev): VM context and instance should be shared between multiple // executables. 
Also HAL module should be loaded just once. This has to be // fixed together with efficient device sharing, because HAL VM module diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc index f40c00e1f0487f..4f5abe0335dd44 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.cc @@ -33,9 +33,9 @@ using tsl::profiler::ScopedAnnotation; // XLA:GPU gemm API //===-----------------------------------------------------------------------===/ -static StatusOr GetGemmConfig(const iree_hal_buffer_view_t* lhs, - const iree_hal_buffer_view_t* rhs, - const iree_hal_buffer_view_t* out, +static StatusOr GetGemmConfig(iree_hal_buffer_view_t* lhs, + iree_hal_buffer_view_t* rhs, + iree_hal_buffer_view_t* out, const vm::DotConfig& config) { int64_t compute_precision = config.dot_precision->precision.empty() @@ -43,11 +43,11 @@ static StatusOr GetGemmConfig(const iree_hal_buffer_view_t* lhs, : *absl::c_max_element(config.dot_precision->precision); return GemmConfig::For( - GetShape(lhs), config.dot_dimension_numbers->lhs_batch_dims, + GetBufferShape(lhs), config.dot_dimension_numbers->lhs_batch_dims, config.dot_dimension_numbers->lhs_contracting_dims, // lhs - GetShape(rhs), config.dot_dimension_numbers->rhs_batch_dims, + GetBufferShape(rhs), config.dot_dimension_numbers->rhs_batch_dims, config.dot_dimension_numbers->rhs_contracting_dims, // rhs - GetShape(out), // out + GetBufferShape(out), // out config.alpha_real, config.alpha_imag, config.beta, config.algorithm, compute_precision); } diff --git a/tensorflow/compiler/xla/service/gpu/openxla/gemm.h b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h index cf81ce0c4dc7e1..5d3de4bc2287b7 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/gemm.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/gemm.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_GEMM_H_ #include -#include +#include "absl/container/inlined_vector.h" #include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep #include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep #include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" @@ -27,20 +27,20 @@ limitations under the License. 
namespace xla::gpu { //===-----------------------------------------------------------------------===/ -// XLA:GPU custom module module types +// XLA:GPU gemm API custom types //===-----------------------------------------------------------------------===/ namespace vm { struct DotDimensionNumbers : public iree::vm::RefObject { - std::vector lhs_batch_dims; - std::vector rhs_batch_dims; - std::vector lhs_contracting_dims; - std::vector rhs_contracting_dims; + absl::InlinedVector lhs_batch_dims; + absl::InlinedVector rhs_batch_dims; + absl::InlinedVector lhs_contracting_dims; + absl::InlinedVector rhs_contracting_dims; }; struct DotPrecision : public iree::vm::RefObject { - std::vector precision; + absl::InlinedVector precision; }; struct DotConfig : public iree::vm::RefObject { diff --git a/tensorflow/compiler/xla/service/gpu/openxla/hal.cc b/tensorflow/compiler/xla/service/gpu/openxla/hal.cc index 173c52795c5579..028296439648b6 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/hal.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/hal.cc @@ -27,7 +27,7 @@ namespace xla::gpu { //===----------------------------------------------------------------------===// // TODO(ezhulenev): Add support for all element types. -static PrimitiveType GetElementType(const iree_hal_buffer_view_t* view) { +static PrimitiveType GetElementType(iree_hal_buffer_view_t* view) { switch (iree_hal_buffer_view_element_type(view)) { case IREE_HAL_ELEMENT_TYPE_FLOAT_32: return PrimitiveType::F32; @@ -37,14 +37,13 @@ static PrimitiveType GetElementType(const iree_hal_buffer_view_t* view) { } } -static absl::InlinedVector GetDims( - const iree_hal_buffer_view_t* view) { +static absl::InlinedVector GetDims(iree_hal_buffer_view_t* view) { const iree_host_size_t* dims = iree_hal_buffer_view_shape_dims(view); iree_host_size_t rank = iree_hal_buffer_view_shape_rank(view); return absl::InlinedVector(dims, dims + rank); } -Shape GetShape(const iree_hal_buffer_view_t* view) { +Shape GetBufferShape(iree_hal_buffer_view_t* view) { return ShapeUtil::MakeShape(GetElementType(view), GetDims(view)); } @@ -53,7 +52,7 @@ StatusOr GetDeviceMemory( // Get original allocation behind a buffer subspan. iree_hal_buffer_t* allocated = iree_hal_buffer_allocated_buffer(buffer); - // Export original allocation as a device allocation. + // Export allocated buffer as an external device allocation. iree_hal_external_buffer_t external_buffer; iree::Status exported = iree_hal_allocator_export_buffer( device_allocator, allocated, diff --git a/tensorflow/compiler/xla/service/gpu/openxla/hal.h b/tensorflow/compiler/xla/service/gpu/openxla/hal.h index ea33958cc01ea5..cc1d6e39586bdc 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/hal.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/hal.h @@ -26,7 +26,7 @@ namespace xla::gpu { // Helper functions to work with IREE buffers and buffer views //===----------------------------------------------------------------------===// -Shape GetShape(const iree_hal_buffer_view_t* view); +Shape GetBufferShape(iree_hal_buffer_view_t* view); StatusOr GetDeviceMemory( iree_hal_allocator_t* device_allocator, iree_hal_buffer_t* buffer); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/kernel.cc b/tensorflow/compiler/xla/service/gpu/openxla/kernel.cc new file mode 100644 index 00000000000000..176147dfd18458 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/kernel.cc @@ -0,0 +1,107 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/openxla/kernel.h" + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/hal.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/stream_executor/device_memory.h" + +namespace xla::gpu { + +//===-----------------------------------------------------------------------===/ +// XLA:GPU kernel dispatch API +//===-----------------------------------------------------------------------===/ + +Status DispatchKernel(const vm::ExecutionContext& ctx, const vm::Kernel& kernel, + iree_hal_allocator_t* device_allocator, + absl::Span args, + LaunchDimensions dims) { + se::Stream* stream = ctx.run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + + // TODO(ezhulenev): Keep a cache of loaded kernels for each executor. + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel_base, + CreateKernel(kernel.kernel_name, args.size(), ctx.executable_source.ptx, + ctx.executable_source.cubin, executor, + kernel.shared_memory_bytes)); + + absl::InlinedVector device_args; + for (iree_hal_buffer_view_t* arg : args) { + TF_ASSIGN_OR_RETURN(device_args.emplace_back(), + GetDeviceMemory(device_allocator, arg)); + } + + return ExecuteKernelOnStream(*kernel_base, device_args, dims, stream); +} + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module kernel dispatch API +//===-----------------------------------------------------------------------===/ + +namespace vm { + +// TODO(ezhulenev): We need to find a way to pass original Status back to the +// caller preserving the location and stack frame. Can we use some diagnostic +// side channel via the ExecutionContext? +static iree::Status FromStatus(Status status) { + if (status.ok()) return iree_ok_status(); + + // TODO(ezhulenev): Convert from ABSL to IREE error code. + std::string err = status.ToString(); + return iree_make_status(IREE_STATUS_INTERNAL, "internal error: %s", + err.c_str()); +} + +KernelAPI::KernelAPI(iree_hal_allocator_t* device_allocator) + : device_allocator_(device_allocator) {} + +iree::StatusOr> KernelAPI::KernelCreate( + iree_string_view_t kernel_name, int32_t shared_memory_bytes) { + auto ref = iree::vm::make_ref(); + ref->kernel_name = std::string(kernel_name.data, kernel_name.size); + ref->shared_memory_bytes = shared_memory_bytes; + return ref; +} + +iree::Status KernelAPI::KernelDispatch( + iree::vm::ref ctx, iree::vm::ref kernel, + iree::vm::ref args, int32_t workgroup_size_x, + int32_t workgroup_size_y, int32_t workgroup_size_z, int32_t workload_size_x, + int32_t workload_size_y, int32_t workload_size_z) { + // Kernel launch dimensions + shared memory requirement. 
+ LaunchDimensions launch_dimensions( + {workload_size_x, workload_size_y, workload_size_z}, + {workgroup_size_x, workgroup_size_y, workgroup_size_z}); + launch_dimensions.SetSharedMemBytes(kernel->shared_memory_bytes); + + IREE_ASSIGN_OR_RETURN(auto buffer_views, GetBufferViewVector(args.get())); + return FromStatus(DispatchKernel(*ctx, *kernel, device_allocator_, + {buffer_views.data(), buffer_views.size()}, + launch_dimensions)); +} + +} // namespace vm +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DEFINE_TYPE_ADAPTERS(kernel, xla::gpu::vm::Kernel); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/kernel.h b/tensorflow/compiler/xla/service/gpu/openxla/kernel.h new file mode 100644 index 00000000000000..caebd5899540bd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/openxla/kernel.h @@ -0,0 +1,88 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_KERNEL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_KERNEL_H_ + +#include + +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" + +namespace xla::gpu { + +namespace vm { + +//===-----------------------------------------------------------------------===/ +// XLA:GPU kernel API custom types +//===-----------------------------------------------------------------------===/ + +struct Kernel : public iree::vm::RefObject { + std::string kernel_name; + int32_t shared_memory_bytes; +}; + +} // namespace vm + +//===-----------------------------------------------------------------------===/ +// XLA:GPU kernel API +//===-----------------------------------------------------------------------===/ + +Status DispatchKernel(const vm::ExecutionContext& ctx, const vm::Kernel& kernel, + iree_hal_allocator_t* device_allocator, + absl::Span args, + LaunchDimensions dims); + +//===-----------------------------------------------------------------------===/ +// XLA:GPU custom module kernel dispatch API +//===-----------------------------------------------------------------------===/ + +namespace vm { + +class KernelAPI { + public: + explicit KernelAPI(iree_hal_allocator_t* device_allocator); + + iree::StatusOr> KernelCreate( + iree_string_view_t kernel_name, int32_t shared_memory_bytes); + + // Dispatches device kernel with given buffers and parameters. 
+ iree::Status KernelDispatch(iree::vm::ref ctx, + iree::vm::ref kernel, + iree::vm::ref args, + // Workgroup size (block size) + int32_t workgroup_size_x, + int32_t workgroup_size_y, + int32_t workgroup_size_z, + // Workload size (grid size) + int32_t workload_size_x, int32_t workload_size_y, + int32_t workload_size_z); + + private: + iree_hal_allocator_t* device_allocator_; +}; + +} // namespace vm +} // namespace xla::gpu + +//===----------------------------------------------------------------------===// +// Register types with IREE VM +//===----------------------------------------------------------------------===// + +IREE_VM_DECLARE_TYPE_ADAPTERS(kernel, xla::gpu::vm::Kernel); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_KERNEL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/openxla/module.cc b/tensorflow/compiler/xla/service/gpu/openxla/module.cc index e9e489868818dc..6b961dc8ee2a86 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/module.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/module.cc @@ -24,6 +24,7 @@ limitations under the License. #include "third_party/iree/runtime/src/iree/vm/native_module_cc.h" #include "third_party/iree/runtime/src/iree/vm/native_module_packing.h" #include "tensorflow/compiler/xla/service/gpu/openxla/gemm.h" +#include "tensorflow/compiler/xla/service/gpu/openxla/kernel.h" #include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" namespace xla::gpu { @@ -33,12 +34,13 @@ namespace xla::gpu { //===-----------------------------------------------------------------------===/ using vm::GemmAPI; +using vm::KernelAPI; using vm::TraceAPI; -class XlaGpuModuleState : public GemmAPI, public TraceAPI { +class XlaGpuModuleState : public GemmAPI, public KernelAPI, public TraceAPI { public: explicit XlaGpuModuleState(iree_hal_allocator_t* device_allocator) - : GemmAPI(device_allocator) {} + : GemmAPI(device_allocator), KernelAPI(device_allocator) {} }; //===----------------------------------------------------------------------===// @@ -81,6 +83,10 @@ static const iree::vm::NativeFunction kXlaGpuFunctions[] = { MakeApiFunction("dot_config.create", &GemmAPI::DotConfigCreate), MakeApiFunction("gemm.dispatch", &GemmAPI::GemmDispatch), + // XLA:GPU kernel APIs + MakeApiFunction("kernel.create", &KernelAPI::KernelCreate), + MakeApiFunction("kernel.dispatch", &KernelAPI::KernelDispatch), + // XLA:GPU tracing APIs MakeApiFunction("trace.create", &TraceAPI::TraceCreate), }; @@ -161,6 +167,10 @@ iree_status_t RegisterXlaGpuTypes(iree_vm_instance_t* instance) { IREE_RETURN_IF_ERROR(RegisterType( instance, "xla_gpu.dot_config", &dot_config_registration)); + // XLA:GPU kernel dispatch types + IREE_RETURN_IF_ERROR(RegisterType(instance, "xla_gpu.kernel", + &kernel_registration)); + // XLA:GPU tracing types IREE_RETURN_IF_ERROR( RegisterType(instance, "xla_gpu.trace", &trace_registration)); diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.cc b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc index dea05d4d3d68d9..75ad598d653289 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/vm.cc +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.cc @@ -16,9 +16,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/openxla/vm.h" #include -#include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_format.h" +#include "third_party/iree/runtime/src/iree/modules/hal/types.h" namespace xla::gpu::vm { @@ -41,16 +42,30 @@ iree::StatusOr> TraceAPI::TraceCreate( // Helper functions to work with VM lists //===----------------------------------------------------------------------===// -iree::StatusOr> GetI64Vector(const iree_vm_list_t* list) { +iree::StatusOr> +GetBufferViewVector(iree_vm_list_t* list) { iree_host_size_t size = iree_vm_list_size(list); - std::vector values(size); + absl::InlinedVector vector(size); + + for (iree_host_size_t i = 0; i < size; ++i) { + iree_vm_ref_t ref{nullptr}; + IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_assign(list, i, &ref)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(ref, &vector[i])); + } + return vector; +} + +iree::StatusOr> GetI64Vector( + iree_vm_list_t* list) { + iree_host_size_t size = iree_vm_list_size(list); + absl::InlinedVector vector(size); for (iree_host_size_t i = 0; i < size; ++i) { iree_vm_value_t value; IREE_RETURN_IF_ERROR( iree_vm_list_get_value_as(list, i, IREE_VM_VALUE_TYPE_I64, &value)); - values[i] = value.i64; + vector[i] = value.i64; } - return values; + return vector; } } // namespace xla::gpu::vm diff --git a/tensorflow/compiler/xla/service/gpu/openxla/vm.h b/tensorflow/compiler/xla/service/gpu/openxla/vm.h index 0c9d84ea919d84..3fc31dbe59aa35 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/vm.h +++ b/tensorflow/compiler/xla/service/gpu/openxla/vm.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OPENXLA_VM_H_ #include -#include -#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep +#include "absl/container/inlined_vector.h" +#include "third_party/iree/runtime/src/iree/hal/api.h" // IWYU pragma: keep +#include "third_party/iree/runtime/src/iree/vm/api.h" // IWYU pragma: keep namespace xla { @@ -36,12 +37,22 @@ namespace gpu::vm { // runtime APIs. For example through `run_options` pointer we get access to // the current compute stream, stream borrower, parent executor, etc. struct ExecutionContext : public iree::vm::RefObject { + // XLA:GPU kernels compiled to PTX/CUBIN (for NVIDIA platform). 
+ struct ExecutableSource { + const std::string_view ptx; + const absl::Span cubin; + }; + ExecutionContext(const ServiceExecutableRunOptions* run_options, - const DebugOptions* debug_options) - : run_options(run_options), debug_options(debug_options) {} + const DebugOptions* debug_options, + ExecutableSource executable_source) + : run_options(run_options), + debug_options(debug_options), + executable_source(executable_source) {} const ServiceExecutableRunOptions* run_options; const DebugOptions* debug_options; + ExecutableSource executable_source; }; //===----------------------------------------------------------------------===// @@ -64,7 +75,11 @@ struct TraceAPI { // Helper functions to work with VM lists //===----------------------------------------------------------------------===// -iree::StatusOr> GetI64Vector(const iree_vm_list_t* list); +iree::StatusOr> +GetBufferViewVector(iree_vm_list_t* list); + +iree::StatusOr> GetI64Vector( + iree_vm_list_t* list); } // namespace gpu::vm } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 9431fe3a93d3af..7ee2bb1d7090c1 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -423,7 +423,7 @@ message DebugOptions { bool xla_gpu_enable_xla_runtime_executable = 169; // If true, use OpenXLA runtime for XLA:GPU backend. That is, use IREE VM - // as a host executable, CUDA HAL for dispatching device kernels and + // as a host executable, optional CUDA HAL for dispatching device kernels and // custom modules for integration with libraries required for running // XLA:GPU programs. // @@ -431,6 +431,11 @@ message DebugOptions { // is defined above. bool xla_gpu_enable_openxla_runtime = 233; + // If true, use OpenXLA hardware abstraction layer (aka CUDA HAL) to dispatch + // device kernels, otherwise use StreamExecutor kernel launch APIs. Has any + // effect only if `xla_gpu_enable_openxla_runtime` is set to true. + bool xla_gpu_enable_openxla_hal = 234; + // Timeout in seconds before terminating jobs that are stuck in a NCCL // Rendezvous. Negative value disables the timeout and will not terminate. int64 xla_gpu_nccl_termination_timeout_seconds = 163; @@ -589,7 +594,7 @@ message DebugOptions { bool xla_gpu_dump_autotuned_triton_fusions = 232; - // Next id: 234 + // Next id: 235 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. 
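For clarity, a small sketch of how the two DebugOptions flags are expected to combine, mirroring the selection logic this patch adds to the compiler; `ChooseOpenXlaBackend` is an illustrative name only, while the flag accessors are the ones generated from the proto fields above:

#include "tensorflow/compiler/xla/mlir/backends/openxla/transforms/passes.h"
#include "tensorflow/compiler/xla/xla.pb.h"

// Only meaningful when xla_gpu_enable_openxla_runtime is true; otherwise the
// regular XLA:GPU runtime is used and this choice is never consulted.
static xla::gpu::OpenXlaBackend ChooseOpenXlaBackend(
    const xla::DebugOptions& debug_options) {
  return debug_options.xla_gpu_enable_openxla_hal()
             ? xla::gpu::OpenXlaBackend::kHAL
             : xla::gpu::OpenXlaBackend::kStreamExecutor;
}
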
From 6d13c55dc7d6df3cfebf2b979cb860d869733d25 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 28 Jul 2023 02:21:35 -0700 Subject: [PATCH 302/410] [xla:gpu] NFC: Use single flag to control experimental compiler/runtime build PiperOrigin-RevId: 551788871 --- .../compiler/xla/mlir/backends/openxla/BUILD | 16 ++++++++++------ .../xla/mlir/backends/openxla/build_config.bzl | 4 ++-- .../compiler/xla/service/gpu/openxla/BUILD | 11 +---------- .../xla/service/gpu/openxla/build_config.bzl | 13 ------------- 4 files changed, 13 insertions(+), 31 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD index c381bbc7b9b27d..aad55104601b1d 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/openxla/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@bazel_skylib//rules:build_test.bzl", "build_test") package( @@ -7,13 +8,16 @@ package( licenses = ["notice"], ) -# Add `--define=xla_gpu_with_openxla_compiler=1` to build command to enable experimental -# OpenXLA/IREE backend compiler. +# Add `--//third_party/tensorflow/compiler/xla/mlir/backends/openxla:enable` to build command to +# enable experimental backend compiler and runtime. +bool_flag( + name = "enable", + build_setting_default = False, +) + config_setting( - name = "with_openxla_compiler", - values = { - "define": "xla_gpu_with_openxla_compiler=1", - }, + name = "enabled", + flag_values = {":enable": "True"}, ) # copybara:uncomment_begin(not supported in OSS build) diff --git a/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl b/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl index 3ab222a6da6929..935d82992db91b 100644 --- a/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl +++ b/tensorflow/compiler/xla/mlir/backends/openxla/build_config.bzl @@ -2,12 +2,12 @@ def if_openxla(then, otherwise = []): return select({ - "//tensorflow/compiler/xla/mlir/backends/openxla:with_openxla_compiler": then, + "//tensorflow/compiler/xla/mlir/backends/openxla:enabled": then, "//conditions:default": otherwise, }) def if_not_openxla(then, otherwise = []): return select({ - "//tensorflow/compiler/xla/mlir/backends/openxla:with_openxla_compiler": otherwise, + "//tensorflow/compiler/xla/mlir/backends/openxla:enabled": otherwise, "//conditions:default": then, }) diff --git a/tensorflow/compiler/xla/service/gpu/openxla/BUILD b/tensorflow/compiler/xla/service/gpu/openxla/BUILD index 30273d26a86460..047cb73d5548ee 100644 --- a/tensorflow/compiler/xla/service/gpu/openxla/BUILD +++ b/tensorflow/compiler/xla/service/gpu/openxla/BUILD @@ -1,6 +1,6 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/platform:build_config.bzl", "tf_platform_deps") -load("//tensorflow/compiler/xla/service/gpu/openxla:build_config.bzl", "if_not_openxla", "if_openxla") +load("//tensorflow/compiler/xla/mlir/backends/openxla:build_config.bzl", "if_not_openxla", "if_openxla") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -13,15 +13,6 @@ package_group( includes = ["//tensorflow/compiler/xla:friends"], ) -# Add `--define=xla_gpu_with_openxla_runtime=1` to build command to enable experimental OpenXLA/IREE -# backend for XLA:GPU executables. 
-config_setting( - name = "with_openxla_runtime", - values = { - "define": "xla_gpu_with_openxla_runtime=1", - }, -) - # copybara:uncomment_begin(not supported in OSS build) # # cc_library( diff --git a/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl b/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl deleted file mode 100644 index 1599be4c4374c1..00000000000000 --- a/tensorflow/compiler/xla/service/gpu/openxla/build_config.bzl +++ /dev/null @@ -1,13 +0,0 @@ -"""Helpers for conditional OpenXLA compilation.""" - -def if_openxla(then, otherwise = []): - return select({ - ":with_openxla_runtime": then, - "//conditions:default": otherwise, - }) - -def if_not_openxla(then, otherwise = []): - return select({ - ":with_openxla_runtime": otherwise, - "//conditions:default": then, - }) From 8ba454c752ea546d60c6b25a4a70a7269569d3cb Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 28 Jul 2023 02:31:16 -0700 Subject: [PATCH 303/410] [XLA:GPU] Limit the use of split-K in Triton GEMMs. Split-K is unnecessary for GEMMs with sufficiently large outputs, probing it during autotuning is slow and requires lots of temporary memory. PiperOrigin-RevId: 551790648 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/triton_autotuner.cc | 47 +++++++++++++++---- .../xla/service/gpu/triton_autotuner.h | 5 +- .../xla/service/gpu/triton_autotuner_test.cc | 38 ++++++++++++++- 4 files changed, 78 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 34b6bf9b5e9382..1b09d9408c55fb 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -609,6 +609,7 @@ cc_library( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "//tensorflow/compiler/xla:autotuning_proto_cc", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc index e0587af9197537..0773e5823c6196 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/stream_executor/device_description.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" @@ -89,6 +90,9 @@ static AutotuneResult::TritonGemmKey GemmKey(int64_t block_m, int64_t block_n, return key; } +// Not a hard limit, just an assumption that should stay valid. 
+constexpr int kMaxTileSize = 512; + struct CompilationKey { template friend H AbslHashValue(H h, const CompilationKey& k) { @@ -156,7 +160,6 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { allocator = stream_exec->GetAllocator(); } - HloInstruction* root = fusion.root_instruction(); TF_ASSIGN_OR_RETURN(se::Stream* const stream, allocator->GetStream(stream_exec->device_ordinal())); @@ -167,11 +170,12 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { AutotunerUtil::CreateRedzoneAllocator(config_, debug_opts)); std::optional reference_buffer; - BufferComparator comparator(root->shape(), fusion.parent()->config()); + const HloInstruction& root = *fusion.root_instruction(); + BufferComparator comparator(root.shape(), fusion.parent()->config()); const std::vector configurations = GetPossibleMatmulAutotuneConfigs( - stream_exec->GetDeviceDescription().cuda_compute_capability(), + root, stream_exec->GetDeviceDescription().cuda_compute_capability(), config_.ExhaustiveTilingSearch()); GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(config_.GetExecutor()); @@ -181,6 +185,9 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { executables; auto compile = [&](const AutotuneResult::TritonGemmKey& conf) { + CHECK(conf.block_m() <= kMaxTileSize); + CHECK(conf.block_n() <= kMaxTileSize); + CHECK(conf.block_k() <= kMaxTileSize); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, autotuner_compile_util_->Compile([&] { return TritonGemmAutotuneExtractor( @@ -285,7 +292,7 @@ class TritonAutotunerVisitor : public DfsHloRewriteVisitor { TF_ASSIGN_OR_RETURN( AutotuneResult best, - PickBestResult(results, root->ToString(), root->GetModule()->config())); + PickBestResult(results, root.ToString(), root.GetModule()->config())); if (debug_opts.xla_gpu_dump_autotuned_triton_fusions()) { TF_ASSIGN_OR_RETURN( @@ -392,7 +399,7 @@ constexpr std::array NUM_WARPS = {2, 4, 8, 16}; constexpr std::array SPLIT_K = {1, 2, 4, 8, 16}; std::vector GetExhaustiveMatmulAutotuneConfigs( - const se::CudaComputeCapability compute_capability) { + const se::CudaComputeCapability compute_capability, const int max_split_k) { std::vector configs; bool mma_layout_v2 = compute_capability.IsAtLeast(se::CudaComputeCapability::AMPERE); @@ -410,6 +417,9 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( } for (int block_k : BLOCK_SIZES) { for (int split_k : SPLIT_K) { + if (split_k > max_split_k) { + continue; + } auto config = GemmKey(block_m, block_n, block_k, split_k, num_stages, num_warps); configs.push_back(std::move(config)); @@ -423,7 +433,7 @@ std::vector GetExhaustiveMatmulAutotuneConfigs( } std::vector GetFixedMatmulAutotuneConfigs( - const se::CudaComputeCapability compute_capability) { + const se::CudaComputeCapability compute_capability, const int max_split_k) { std::vector configs = { GemmKey(32, 32, 256, 1, 1, 4), GemmKey(64, 32, 32, 16, 1, 4), GemmKey(32, 64, 64, 4, 1, 4), GemmKey(128, 128, 64, 4, 1, 4), @@ -457,17 +467,38 @@ std::vector GetFixedMatmulAutotuneConfigs( }), configs.end()); } + configs.erase( + std::remove_if(configs.begin(), configs.end(), + [&](const AutotuneResult::TritonGemmKey& config) { + return config.split_k() > max_split_k; + }), + configs.end()); return configs; } } // anonymous namespace std::vector GetPossibleMatmulAutotuneConfigs( + const HloInstruction& instr, const se::CudaComputeCapability compute_capability, bool exhaustive_tiling_search) { + // Split-K optimization enables more even utilization of a GPU in cases + // where tiling just the non-contracting 
dimensions of a GEMM does not create + // a sufficient number of thread block programs to occupy all available cores. + // Given the typical ~100 cores per GPU 500 tiles make around 5 full + // waves that completely avoid the need for split-K. The formula below is + // n_tiles = split_k * (M * N) / (block_m * block_n) + // with pessimistically assumed maximum block_m and block_n. + // Most likely there is no need for split-K already at much smaller output + // tensor sizes. + constexpr int kSufficientNumberOfTiles = 500; + const int max_split_k = + std::max(1L, kSufficientNumberOfTiles * kMaxTileSize * kMaxTileSize / + ShapeUtil::ElementsIn(instr.shape())); return exhaustive_tiling_search - ? GetExhaustiveMatmulAutotuneConfigs(compute_capability) - : GetFixedMatmulAutotuneConfigs(compute_capability); + ? GetExhaustiveMatmulAutotuneConfigs(compute_capability, + max_split_k) + : GetFixedMatmulAutotuneConfigs(compute_capability, max_split_k); } StatusOr TritonAutotuner::Run( diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner.h b/tensorflow/compiler/xla/service/gpu/triton_autotuner.h index 0994425d9ef5a9..9b37c049208fe7 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner.h +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner.h @@ -15,7 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRITON_AUTOTUNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRITON_AUTOTUNER_H_ -#include #include #include "absl/container/flat_hash_set.h" @@ -50,9 +49,9 @@ class TritonAutotuner : public HloModulePass { }; // TODO(b/266210099): have a way to generate/load these dynamically. -// Returns a list of possible tilings for a gemm performed in Triton. +// Returns a list of possible tilings for a GEMM performed in Triton. 
std::vector GetPossibleMatmulAutotuneConfigs( - se::CudaComputeCapability compute_capability, + const HloInstruction& instr, se::CudaComputeCapability compute_capability, bool exhaustive_tiling_search = false); } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/triton_autotuner_test.cc b/tensorflow/compiler/xla/service/gpu/triton_autotuner_test.cc index 1341b9061d9682..5ca39c6b0cb6fd 100644 --- a/tensorflow/compiler/xla/service/gpu/triton_autotuner_test.cc +++ b/tensorflow/compiler/xla/service/gpu/triton_autotuner_test.cc @@ -222,7 +222,10 @@ TEST_F(TritonAutotunerTest, VoltaUsesNoMoreThanTwoStages) { const se::CudaComputeCapability compute_capability{ se::CudaComputeCapability::VOLTA, /*minor=*/0}; const std::vector configs = - GetPossibleMatmulAutotuneConfigs(compute_capability); + GetPossibleMatmulAutotuneConfigs( + *HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 1024}), ""), + compute_capability); EXPECT_FALSE(std::any_of(configs.begin(), configs.end(), [](const AutotuneResult::TritonGemmKey& key) { return key.num_stages() > 2; @@ -233,13 +236,44 @@ TEST_F(TritonAutotunerTest, AmpereUsesMoreThanTwoStages) { const se::CudaComputeCapability compute_capability{ se::CudaComputeCapability::AMPERE, /*minor=*/0}; const std::vector configs = - GetPossibleMatmulAutotuneConfigs(compute_capability); + GetPossibleMatmulAutotuneConfigs( + *HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 1024}), ""), + compute_capability); EXPECT_TRUE(std::any_of(configs.begin(), configs.end(), [](const AutotuneResult::TritonGemmKey& key) { return key.num_stages() > 2; })); } +TEST_F(TritonAutotunerTest, SmallOutputCanUseLargeSplitK) { + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + const std::vector configs = + GetPossibleMatmulAutotuneConfigs( + *HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 1024}), ""), + compute_capability); + EXPECT_TRUE(std::any_of(configs.begin(), configs.end(), + [](const AutotuneResult::TritonGemmKey& key) { + return key.split_k() >= 16; + })); +} + +TEST_F(TritonAutotunerTest, LargeOutputDoesNotUseLargeSplitK) { + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + const std::vector configs = + GetPossibleMatmulAutotuneConfigs( + *HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20480, 20480}), ""), + compute_capability); + EXPECT_FALSE(std::any_of(configs.begin(), configs.end(), + [](const AutotuneResult::TritonGemmKey& key) { + return key.split_k() > 1; + })); +} + TEST_F(TritonAutotunerTest, Int8FusedGemm) { const std::string hlo = R"( HloModule module From 0967a0f6b8d8cf153ccaf41e4d57dc63a054882f Mon Sep 17 00:00:00 2001 From: Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com> Date: Fri, 28 Jul 2023 03:17:58 -0700 Subject: [PATCH 304/410] PR #61110: Adds matmul heuristic for oneDNN ACL builds on AArch64 Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/61110 This PR follows an approach of using heuristics to choose whether to rewrite a oneDNN matmul node via mkl_layout_pass for AArch64 builds. The heuristics is based on a decision tree model with the shapes and number of ops as features. This work is based on three PRs in upstream TensorFlow: 1. https://github.com/tensorflow/tensorflow/pull/60160 that uses the rewrite pass for the convolution microbenchmarks, which has been merged. 2. 
https://github.com/tensorflow/tensorflow/pull/60026 that uses a heuristic based on a linear model to choose between oneDNN and Eigen. The changes in this PR include the following: - Use the rewrite pass for matmul microbenchmarks - Use heuristic for matmul ops to decide when to rewrite a matmul node and is guarded for AArch64 builds Performance impact: With this PR, we show the performance before and after introducing this patch for 8 cores on Neoverse V1 platforms for default thread settings. ### Matmul microbenchmarks Before the patch ![image](https://github.com/tensorflow/tensorflow/assets/65665931/6e0f033c-aeb2-490e-8f78-7b08b4d4b19e) After the patch ![image](https://github.com/tensorflow/tensorflow/assets/65665931/a313fc01-c91c-4ca2-b7b8-5b6659c439a1) ### NLP models from hugging face ![image](https://github.com/tensorflow/tensorflow/assets/65665931/702fe640-c274-4dfe-83b0-6ef43537bf03) Copybara import of the project: -- 3ff50a6a46f1350e8b4f310a8fdac65e1d7498d5 by Crefeda Rodrigues : Update cpu_info to detect Aarch64 CPUs Change-Id: Ie6f132e28f801c1e7f68ab06b7e582eed1506ac9 -- 741149aae46ceb1001f93eeebd6400c8c0c845bc by Crefeda Rodrigues : Add matmul heuristics for oneDNN Aarch64 -- 2865e77522dac02d4b726fc83e54e77df6975477 by Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>: Update tensorflow/tsl/platform/cpu_info.cc Co-authored-by: Penporn Koanantakool <38085909+penpornk@users.noreply.github.com> -- f901f19399a93bae2a028c40f493b2e588db3c06 by Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>: Update tensorflow/tsl/platform/cpu_info.cc Co-authored-by: Penporn Koanantakool <38085909+penpornk@users.noreply.github.com> -- f922e0ec78b79345c107aae6db6ca3a459ee2c7a by Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com>: Update tensorflow/tsl/platform/cpu_info.cc Co-authored-by: Penporn Koanantakool <38085909+penpornk@users.noreply.github.com> Merging this change closes #61110 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/61110 from cfRod:onednn_acl_matmul_heuristic f922e0ec78b79345c107aae6db6ca3a459ee2c7a PiperOrigin-RevId: 551801101 --- .../core/common_runtime/mkl_layout_pass.cc | 4 ++ tensorflow/core/graph/mkl_testlib.cc | 13 ---- tensorflow/core/graph/mkl_testlib.h | 4 -- .../core/grappler/optimizers/remapper.cc | 26 +++++-- .../kernels/mkl/mkl_matmul_op_benchmark.cc | 36 ++++++++-- tensorflow/core/platform/cpu_info.h | 2 + tensorflow/core/util/mkl_heuristics.h | 68 +++++++++++++++++++ tensorflow/tsl/platform/cpu_info.cc | 49 +++++++++++-- tensorflow/tsl/platform/cpu_info.h | 7 ++ 9 files changed, 178 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 5e0ed106ed6891..050d98176e121e 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -1520,7 +1520,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T)); if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) { VLOG(2) << "Rewriting MatMul to _MklMatMul"; +#ifdef DNNL_AARCH64_USE_ACL + return MatMulHeuristic(n); +#else return true; +#endif } return false; } diff --git a/tensorflow/core/graph/mkl_testlib.cc b/tensorflow/core/graph/mkl_testlib.cc index 2029207ee43278..05d0b67d3e1a16 100644 --- a/tensorflow/core/graph/mkl_testlib.cc +++ b/tensorflow/core/graph/mkl_testlib.cc @@ -23,19 +23,6 @@ namespace tensorflow { namespace test { namespace graph { -Node* oneDNNMatmul(Graph* g, 
Node* in0, Node* in1, bool transpose_a, - bool transpose_b) { - Node* ret = nullptr; - TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_MklMatMul") - .Input(in0) - .Input(in1) - .Attr("transpose_a", transpose_a) - .Attr("transpose_b", transpose_b) - .Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel) - .Finalize(g, &ret)); - return ret; -} - Node* oneDNNSoftmax(Graph* g, Node* input) { Node* ret = nullptr; TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_MklSoftmax") diff --git a/tensorflow/core/graph/mkl_testlib.h b/tensorflow/core/graph/mkl_testlib.h index 7f3fb726c80c59..733f124168d949 100644 --- a/tensorflow/core/graph/mkl_testlib.h +++ b/tensorflow/core/graph/mkl_testlib.h @@ -24,10 +24,6 @@ namespace tensorflow { namespace test { namespace graph { -// Adds a _MklMatmul node in g doing in0.contract(in1). -Node* oneDNNMatmul(Graph* g, Node* in0, Node* in1, bool transpose_a, - bool transpose_b); - Node* oneDNNSoftmax(Graph* g, Node* input); } // namespace graph diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index d02722cd1af7dc..5af44b3340479f 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -2909,6 +2910,12 @@ void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul, auto& activation_attr = activation->attr(); (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha"); } + if (IsMKLEnabled()) { + auto input_shapes = src_attr.find("_input_shapes"); + if (input_shapes != src_attr.end()) { + (*attr)["_input_shapes"] = input_shapes->second; + } + } } void CopyBatchMatMulAttributes(const NodeDef& batchmatmul, @@ -2921,6 +2928,12 @@ void CopyBatchMatMulAttributes(const NodeDef& batchmatmul, (*attr)["T"] = src_attr.at("T"); (*attr)["adj_x"] = src_attr.at("adj_x"); (*attr)["adj_y"] = src_attr.at("adj_y"); + if (IsMKLEnabled()) { + auto input_shapes = src_attr.find("_input_shapes"); + if (input_shapes != src_attr.end()) { + (*attr)["_input_shapes"] = input_shapes->second; + } + } } void SetFusedOpAttributes(NodeDef* fused, @@ -2960,6 +2973,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op); } else if (IsMatMul(contraction)) { fused_op.set_op(kFusedMatMul); + AddInputShapesAttr(*ctx, matched.contraction); CopyMatMulAttributes(contraction, &fused_op); } else if (IsConv3D(contraction)) { fused_op.set_op(kFusedConv3D); @@ -2997,7 +3011,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, // attr and the value of alpha in case of LeakyRelu activation // creating a copy of the contraction - fused_op.CopyFrom(contraction); + fused_op = contraction; auto* attr = fused_op.mutable_attr(); auto contraction_fused_ops_list = @@ -3065,6 +3079,7 @@ Status AddFusedContractionNode( CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op); } else if (IsMatMul(contraction)) { fused_op.set_op(kFusedMatMul); + AddInputShapesAttr(*ctx, matched.contraction); CopyMatMulAttributes(contraction, &fused_op, &activation); } else if (IsConv3D(contraction)) { fused_op.set_op(kFusedConv3D); @@ -3256,6 +3271,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, AddInputShapesAttr(*ctx, matched.contraction); CopyConv2DAttributes(contraction, &contraction_node); } else if (IsMatMul(contraction)) { + AddInputShapesAttr(*ctx, matched.contraction); contraction_node.set_op(kFusedMatMul); 
CopyMatMulAttributes(contraction, &contraction_node); } else if (IsConv3D(contraction)) { @@ -4490,7 +4506,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, IsDepthwiseConv2dNative(ctx.graph_view.graph()->node(i)) || IsBiasAdd(ctx.graph_view.graph()->node(i)) || IsTranspose(ctx.graph_view.graph()->node(i)) || - IsSigmoid(ctx.graph_view.graph()->node(i))) { + IsSigmoid(ctx.graph_view.graph()->node(i)) || + IsMatMul(ctx.graph_view.graph()->node(i))) { AddInputShapesAttr(ctx, i); } @@ -4600,7 +4617,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, // We need to infer what is the shape of sigmoid AddInputShapesAttr(ctx, sigmoid_idx); const NodeDef* sigmoid = ctx.graph_view.GetNode(sigmoid_idx)->node(); - + const int intra_op_parallelism_threads = + item.optimization_options().intra_op_parallelism_threads; double total_mflops = CalculateNodeMFlops(AttrSlice(*sigmoid), "Sigmoid"); double thr = @@ -4611,7 +4629,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, // so we are not going to rewrite node replace = false; } -#endif +#endif // DNNL_AARCH64_USE_ACL if (replace) { TF_RETURN_IF_ERROR( ReplaceSigmoidMulWithSwish(&ctx, sigmoidmul_matched_nodes_map, diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_benchmark.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_benchmark.cc index bcee0d22f07296..05f62e1058a118 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_benchmark.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_benchmark.cc @@ -16,7 +16,10 @@ limitations under the License. #ifdef INTEL_MKL #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" -#include "tensorflow/core/graph/mkl_testlib.h" +#include "tensorflow/core/common_runtime/mkl_layout_pass.h" +#include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { namespace { @@ -29,9 +32,34 @@ static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b, in0.flat().setRandom(); Tensor in1(type, transpose_b ? TensorShape({n, k}) : TensorShape({k, n})); in1.flat().setRandom(); - test::graph::oneDNNMatmul(g, test::graph::Constant(g, in0), - test::graph::Constant(g, in1), transpose_a, - transpose_b); + + Node* src0 = test::graph::Constant(g, in0); + Node* src1 = test::graph::Constant(g, in1); + g->AddEdge(g->source_node(), 0, src0, 0); + g->AddEdge(g->source_node(), 1, src1, 0); + // Add shape sizes + AttrValue attr_input_shape; + TensorShapeProto* proto = attr_input_shape.mutable_list()->add_shape(); + proto->add_dim()->set_size(m); + proto->add_dim()->set_size(k); + proto = attr_input_shape.mutable_list()->add_shape(); + proto->add_dim()->set_size(k); + proto->add_dim()->set_size(n); + + Node* ret = nullptr; + TF_CHECK_OK(NodeBuilder(g->NewName("matmul"), "MatMul") + .Input(src0) + .Input(src1) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Attr("_input_shapes", attr_input_shape) + .Finalize(g, &ret)); +#ifdef INTEL_MKL + if (IsMKLEnabled()) { + std::unique_ptr* ug = new std::unique_ptr(g); + RunMklLayoutRewritePass(ug); + } +#endif // INTEL_MKL return g; } diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h index 9a776f30c1bc22..3144ce138cd349 100644 --- a/tensorflow/core/platform/cpu_info.h +++ b/tensorflow/core/platform/cpu_info.h @@ -25,6 +25,7 @@ limitations under the License. 
namespace tensorflow { namespace port { +using tsl::port::Aarch64CPU; using tsl::port::ADX; using tsl::port::AES; using tsl::port::AMX_BF16; @@ -80,6 +81,7 @@ using tsl::port::SSE3; using tsl::port::SSE4_1; using tsl::port::SSE4_2; using tsl::port::SSSE3; +using tsl::port::TestAarch64CPU; using tsl::port::TestCPUFeature; } // namespace port diff --git a/tensorflow/core/util/mkl_heuristics.h b/tensorflow/core/util/mkl_heuristics.h index cf991cc39b575e..83d8deae9a7fe0 100644 --- a/tensorflow/core/util/mkl_heuristics.h +++ b/tensorflow/core/util/mkl_heuristics.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/tsl/platform/cpu_info.h" namespace tensorflow { @@ -120,6 +121,73 @@ static double CalculateNodeMFlops(const AttrSlice& attrs, return -1; } +// MatMulHeuristic returns true to rewrite the node with oneDNN +// false to execute the node in Eigen +static bool MatMulHeuristic(const Node* n) { + // Run heuristic only if CPU is ARM_NEOVERSE_V1 + if (!tsl::port::TestAarch64CPU(tsl::port::Aarch64CPU::ARM_NEOVERSE_V1)) { + return true; + } + // Check if we can obtain dimensions for this node. + std::vector shape_attrs; + if (!TryGetNodeAttr(n->attrs(), "_input_shapes", &shape_attrs)) { + // We can't obtain shape so we will revert to default behaviour + // to rewrite node. + return true; + } + + if ((n->type_string() == "MatMul" || n->type_string() == "_FusedMatMul")) { + TensorShape lhs_shape, rhs_shape; + if (TensorShape::BuildTensorShape(*shape_attrs[0], &lhs_shape) != + tsl::OkStatus()) { + return true; + } + if (TensorShape::BuildTensorShape(*shape_attrs[1], &rhs_shape) != + tsl::OkStatus()) { + return true; + } + + auto M = lhs_shape.dim_size(0); + auto K = lhs_shape.dim_size(1); + auto N = rhs_shape.dim_size(1); + auto ops = M * N * K; + std::array n_threshold = {7560, 250, 1536}; + std::array m_threshold = {378, 80}; + std::array ops_threshold = {5242880, 1090519040}; + + if (N <= n_threshold.at(0)) { + if (ops <= ops_threshold.at(0)) { + if (M <= m_threshold.at(0)) { + return false; + } else { + if (N <= n_threshold.at(1)) { + return false; + } else { + return true; + } + } + } else { + if (M <= m_threshold.at(1)) { + if (N <= n_threshold.at(2)) { + return true; + } else { + return false; + } + } else { + if (ops <= ops_threshold.at(1)) { + return true; + } else { + return false; + } + } + } + } else { + return false; + } + } + return true; +} + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/tsl/platform/cpu_info.cc b/tensorflow/tsl/platform/cpu_info.cc index f716fb6b3c5aca..e1835e4118ebdb 100644 --- a/tensorflow/tsl/platform/cpu_info.cc +++ b/tensorflow/tsl/platform/cpu_info.cc @@ -362,16 +362,19 @@ CPUIDInfo *cpuid = nullptr; // Structure for basic CPUID info. class CPUIDInfo { public: - CPUIDInfo() : implementer_(0), variant_(0), cpunum_(0) {} + CPUIDInfo() + : implementer_(0), + variant_(0), + cpunum_(0), + is_arm_neoverse_v1_(0), + is_arm_neoverse_n1_(0) {} static void Initialize() { - // Initialize cpuid struct. - if (cpuid != nullptr) { - return; - } + // Initialize CPUIDInfo pointer. + if (cpuid != nullptr) return; cpuid = new CPUIDInfo; - + // Make sure CPUID registers are available before reading them. if (!(getauxval(AT_HWCAP) & HWCAP_CPUID)) { return; } @@ -415,9 +418,23 @@ class CPUIDInfo { uint32 midr_el1 = std::stoul(line, nullptr, 16); // Unpack variant and CPU ID. 
+ // Reference: + // https://developer.arm.com/documentation/101427/0101/Register-descriptions/AArch64-system-registers/MIDR-EL1--Main-ID-Register--EL1. cpuid->implementer_ = (midr_el1 >> 24) & 0xFF; cpuid->variant_ = (midr_el1 >> 20) & 0xF; cpuid->cpunum_ = (midr_el1 >> 4) & 0xFFF; + if (cpuid->implementer_ == 0x41) { + switch (cpuid->cpunum_) { + case 0xd40: // ARM NEOVERSE V1 + cpuid->is_arm_neoverse_v1_ = 1; + break; + case 0xd0c: // ARM NEOVERSE N1 + cpuid->is_arm_neoverse_n1_ = 1; + break; + default: + break; + } + } } } #endif // !PLATFORM_WINDOWS @@ -426,10 +443,22 @@ class CPUIDInfo { int implementer() const { return implementer_; } int cpunum() const { return cpunum_; } + static bool TestAarch64CPU(Aarch64CPU cpu) { + InitCPUIDInfo(); + switch (cpu) { + case ARM_NEOVERSE_V1: + return cpuid->is_arm_neoverse_v1_; + default: + return 0; + } + } + private: int implementer_; int variant_; int cpunum_; + int is_arm_neoverse_v1_; // ARM NEOVERSE V1 + int is_arm_neoverse_n1_; // ARM NEOVERSE N1 }; absl::once_flag cpuid_once_flag; @@ -450,6 +479,14 @@ bool TestCPUFeature(CPUFeature feature) { #endif } +bool TestAarch64CPU(Aarch64CPU cpu) { +#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) + return CPUIDInfo::TestAarch64CPU(cpu); +#else + return false; +#endif +} + std::string CPUVendorIDString() { #ifdef PLATFORM_IS_X86 InitCPUIDInfo(); diff --git a/tensorflow/tsl/platform/cpu_info.h b/tensorflow/tsl/platform/cpu_info.h index 826c9d31beb230..efb627c09b2b9b 100644 --- a/tensorflow/tsl/platform/cpu_info.h +++ b/tensorflow/tsl/platform/cpu_info.h @@ -134,6 +134,13 @@ enum CPUFeature { AMX_BF16 = 43, // Bfloat16 tile matrix multiplication }; +enum Aarch64CPU { + ARM_NEOVERSE_N1 = 0, // ARM NEOVERSE N1 + ARM_NEOVERSE_V1 = 1, // ARM NEOVERSE V1 +}; +// Checks whether the current AArch64 processor is supported. +bool TestAarch64CPU(Aarch64CPU cpu); + // Checks whether the current processor supports one of the features above. // Checks CPU registers to return hardware capabilities. bool TestCPUFeature(CPUFeature feature); From f3034f991a0423822e78ebbe4fd24279d1b34f71 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Fri, 28 Jul 2023 03:21:53 -0700 Subject: [PATCH 305/410] [XLA:GPU] Enable by default more fusions in Triton GEMMs. PiperOrigin-RevId: 551801728 --- tensorflow/compiler/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 9684f416719d94..ddf54156ffdff8 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -168,7 +168,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true); opts.set_xla_gpu_triton_gemm_any(false); opts.set_xla_gpu_enable_triton_softmax_fusion(true); - opts.set_xla_gpu_triton_fusion_level(1); + opts.set_xla_gpu_triton_fusion_level(2); // Moving reduce-scatter out of while loops can increase memory footprint, so // turning it off by default. From f37474941deda653a586ef702dfebe13b54264bc Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 28 Jul 2023 03:24:19 -0700 Subject: [PATCH 306/410] XlaCallModule: Increase limit of number of error message inputs for shape assertions When I introduced shape assertions I thought that we would only ever need to add two error message inputs, for the operands of binary comparisons that fail. 
Since then we have introduced much more elaborate error messages that sometimes need more inputs: https://github.com/google/jax/pull/16813. Hence, we ran into the limit of 4 maximum error inputs. I do not quite know how to remove the limit easily, so for now we increase it to 32. PiperOrigin-RevId: 551802441 --- .../compiler/tests/xla_call_module_test.py | 11 ++++--- .../xla/python/refine_polymorphic_shapes.cc | 32 ++++++++----------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index 8b75dc450188b1..79f4918f8843ce 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for XLA call module op wrapper.""" import os +import re from typing import Tuple import unittest @@ -472,11 +473,13 @@ def f(x): # x: f32[b, 5] and b = 3, with a constraint b == 4. func.func public @main(%arg1: tensor) -> tensor { %b = "stablehlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor) -> tensor %4 = stablehlo.constant dense<4> : tensor + %5 = stablehlo.constant dense<5> : tensor + %11 = stablehlo.constant dense<11> : tensor %ok = stablehlo.compare EQ, %b, %4, SIGNED : (tensor, tensor) -> tensor - stablehlo.custom_call @shape_assertion(%ok, %b, %4) { - error_message = "Expecting {0} == {1}", + stablehlo.custom_call @shape_assertion(%ok, %b, %4, %5, %4, %5, %4, %5, %4, %5, %4, %5, %11) { + error_message = "Expecting {0} == {1}. Extra {2,=5}, {3}, {{0}, {4}, {5}, {6}, {7}, {11}.", has_side_effect = true - } : (tensor, tensor, tensor) -> () + } : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () return %b : tensor } } @@ -498,7 +501,7 @@ def f(x): # x: f32[b, 5] and b = 3, with a constraint b == 4. else: with self.assertRaisesRegex( errors.InvalidArgumentError, - 'Expecting 3 == 4'): + re.escape('Expecting 3 == 4. Extra 5 , 4, {0}, 5, 4, 5, 4, 11.')): self._assertOpOutputMatchesExpected(f, (x,), (res,)) def test_invalid_shape_assertion(self): diff --git a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc index 01a09c1c244909..08192ac583e10e 100644 --- a/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc +++ b/tensorflow/compiler/xla/python/refine_polymorphic_shapes.cc @@ -44,7 +44,7 @@ namespace { constexpr llvm::StringRef shapeAssertionName = "shape_assertion"; constexpr llvm::StringRef errorMessageAttrName = "error_message"; // We bound the number of error_message_inputs for using llvm::formatv -constexpr int maxErrorMessageInputs = 4; +constexpr int maxErrorMessageInputs = 32; // TODO(necula): Remove this bound // This pass is needed when we have shape assertions. 
A shape assertion is // represented via the `stablehlo.custom_call @shape_assertion` @@ -198,24 +198,18 @@ struct CheckShapeAssertionsPass const mlir::SmallVector &errorMessageInputs) const { int nrErrorMessageInputs = errorMessageInputs.size(); auto errorMessageFormat = errorMessage.data(); - switch (nrErrorMessageInputs) { - case 0: - return errorMessageFormat; - case 1: - return llvm::formatv(errorMessageFormat, errorMessageInputs[0]); - case 2: - return llvm::formatv(errorMessageFormat, errorMessageInputs[0], - errorMessageInputs[1]); - case 3: - return llvm::formatv(errorMessageFormat, errorMessageInputs[0], - errorMessageInputs[1], errorMessageInputs[2]); - case 4: - return llvm::formatv(errorMessageFormat, errorMessageInputs[0], - errorMessageInputs[1], errorMessageInputs[2], - errorMessageInputs[3]); - default: - return errorMessageFormat; - } + if (nrErrorMessageInputs == 0) return errorMessageFormat; + auto errInput = [nrErrorMessageInputs, &errorMessageInputs](int idx) { + return (idx < nrErrorMessageInputs ? errorMessageInputs[idx] : -1); + }; + return llvm::formatv( + errorMessageFormat, errInput(0), errInput(1), errInput(2), errInput(3), + errInput(4), errInput(5), errInput(6), errInput(7), errInput(8), + errInput(9), errInput(10), errInput(11), errInput(12), errInput(13), + errInput(14), errInput(15), errInput(16), errInput(17), errInput(18), + errInput(19), errInput(20), errInput(21), errInput(22), errInput(23), + errInput(24), errInput(25), errInput(26), errInput(27), errInput(28), + errInput(29), errInput(30), errInput(31)); } mlir::StringRef getArgument() const override { From d84a2cf517d213fc3e65b11fcb8d2f7947965ba2 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Fri, 28 Jul 2023 04:07:56 -0700 Subject: [PATCH 307/410] Remove duplicate random dependency from ROCM build. PiperOrigin-RevId: 551811336 --- tensorflow/compiler/xla/service/gpu/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1b09d9408c55fb..81606aab58f6fd 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -984,7 +984,6 @@ cc_library( ]) + if_rocm_is_configured([ "//tensorflow/compiler/xla/stream_executor/rocm:stream_executor_rocm", "@local_config_rocm//rocm:rocm_headers", - "//tensorflow/tsl/platform:random", ]), ) From 0f3da2e8b46fbb09043b0f454e4069d0196a7c6d Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 28 Jul 2023 05:40:35 -0700 Subject: [PATCH 308/410] Introduce the DynamicTopKOp experiment This CL implements a dynamic version of the TopKOp. Once the dynamism RFC is figured out, we expect to have an upstream representation for this notion, but in the meanwhile it is modelled via a `stablehlo.dynamic_top_k` custom call. 
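For illustration, a well-formed use of the custom call looks roughly like this (the shapes and the i32 element type chosen for `k` here are only an example; `k` can be any 0-dimensional tensor of integer or index type, and the last dimension of the results is determined by its value):

  %k = stablehlo.constant dense<4> : tensor<i32>
  %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%operand, %k)
      : (tensor<16xf32>, tensor<i32>) -> (tensor<4xf32>, tensor<4xi32>)

When `k` is a constant that agrees with the static result shapes, the canonicalize-dynamism pass rewrites such a call into a plain `chlo.top_k`; the shape refinement pass uses a constant `k` to refine the result types of the custom call.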
PiperOrigin-RevId: 551827051 --- third_party/stablehlo/temporary.patch | 406 +++++++++++++++++++++++++- 1 file changed, 394 insertions(+), 12 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index ddd7f6c1f59a86..10e2117dd3c60e 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -278,7 +278,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/di diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stablehlo/dialect/ExperimentalOps.cpp --- stablehlo/stablehlo/dialect/ExperimentalOps.cpp +++ stablehlo/stablehlo/dialect/ExperimentalOps.cpp -@@ -0,0 +1,392 @@ +@@ -0,0 +1,504 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -669,12 +669,124 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.cpp b/stablehlo/stableh + return DynamicRngBitGeneratorOpAdaptor(op); +} + ++LogicalResult DynamicTopKOpAdaptor::verify() { ++ if (op_->getNumOperands() != 2) ++ return op_.emitError("expects size(operands) = 2"); ++ if (op_->getNumResults() != 2) ++ return op_.emitError("expects size(results) = 2"); ++ for (const auto& attr : op_->getAttrs()) { ++ // api_version and backend_config have default values. ++ // call_target_name should be "stablehlo.dynamic_top_k". ++ if (attr.getName() != "api_version" && attr.getName() != "backend_config" && ++ attr.getName() != "call_target_name") ++ return op_.emitError() ++ << attr.getName() << " is not a supported attribute"; ++ } ++ if (!op_.getBackendConfig().empty()) ++ return op_.emitError() << "expects an empty backend_config"; ++ if (op_.getCallTargetName() != "stablehlo.dynamic_top_k") ++ return op_.emitError() << "expects @stablehlo.dynamic_top_k"; ++ ++ auto operand = op_.getInputs()[0]; ++ auto k = op_.getInputs()[1]; ++ auto values = op_.getResults()[0]; ++ auto indices = op_.getResults()[1]; ++ ++ // dynamic_top_k_i1 ++ auto operandType = operand.getType().dyn_cast(); ++ if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 || ++ !operandType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects operand #0 " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // dynamic_top_k_i2 ++ auto kType = k.getType().dyn_cast(); ++ if (!kType || !kType.hasRank() || ++ kType.getRank() != 0 || !kType.getElementType().isIntOrIndex()) ++ return op_.emitError() ++ << "expects k (operand #1) " ++ << "to be a 0-dimensional tensor of integer or index type"; ++ ++ // dynamic_top_k_o1 ++ auto valuesType = values.getType().dyn_cast(); ++ if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 || ++ !valuesType.getElementType().isIntOrFloat()) ++ return op_.emitError() ++ << "expects values (result #0) " ++ << "to be a tensor of integer or floating-point type " ++ << "of rank at least 1"; ++ ++ // dynamic_top_k_o2 ++ auto indicesType = indices.getType().dyn_cast(); ++ if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 || ++ !indicesType.getElementType().isSignlessInteger(32)) ++ return op_.emitError() << "expects indices (result #1) " ++ << "to be a tensor of si32 of rank at least 1"; ++ ++ // dynamic_top_k_c1 ++ auto operandLastDim = operandType.getRank() - 1; ++ SmallVector expectedValuesShape(operandType.getShape()); ++ expectedValuesShape[operandLastDim] = ++ valuesType.getDimSize(valuesType.getRank() - 1); ++ if 
(failed(verifyCompatibleShape(expectedValuesShape, valuesType.getShape()))) ++ return op_.emitError() << "expects the values shape to match the operand " ++ "shape in all but the last dimension"; ++ ++ // dynamic_top_k_c2 ++ if (valuesType.getElementType() != operandType.getElementType()) ++ return op_.emitError() ++ << "expects the values element type to be the same as the operand " ++ << "element type"; ++ ++ // dynamic_top_k_c3 ++ if (!operandType.isDynamicDim(operandLastDim) && ++ !valuesType.isDynamicDim(operandLastDim) && ++ operandType.getDimSize(operandLastDim) < ++ valuesType.getDimSize(operandLastDim)) ++ return op_.emitError() << "expects the values last dimension to have size " ++ "at least as large " ++ << "as operand last dimension"; ++ ++ // dynamic_top_k_c4 ++ if (failed( ++ verifyCompatibleShape(indicesType.getShape(), valuesType.getShape()))) ++ return op_.emitError() ++ << "expects the indices shape to match the values shape"; ++ ++ return success(); ++} ++ ++TypedValue DynamicTopKOpAdaptor::getOperand() { ++ return op_.getInputs()[0].cast>(); ++} ++ ++TypedValue DynamicTopKOpAdaptor::getK() { ++ return op_.getInputs()[1].cast>(); ++} ++ ++ ++TypedValue DynamicTopKOpAdaptor::getValues() { ++ return op_.getResults()[0].cast>(); ++} ++ ++TypedValue DynamicTopKOpAdaptor::getIndices() { ++ return op_.getResults()[1].cast>(); ++} ++ ++std::optional getDynamicTopKOp( ++ CustomCallOp op) { ++ if (op.getCallTargetName() != "stablehlo.dynamic_top_k") return {}; ++ return DynamicTopKOpAdaptor(op); ++} ++ +} // namespace stablehlo +} // namespace mlir diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo/dialect/ExperimentalOps.h --- stablehlo/stablehlo/dialect/ExperimentalOps.h +++ stablehlo/stablehlo/dialect/ExperimentalOps.h -@@ -0,0 +1,170 @@ +@@ -0,0 +1,227 @@ +/* Copyright 2023 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); @@ -767,7 +879,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo + CustomCallOp op_; +}; + -+// Wraps a custom call in a DynamicReduceWindowAdaptor. ++// Wraps a custom call in a DynamicReduceWindowOpAdaptor. +// Fails if the call_target_name of the custom call doesn't match +// "stablehlo.dynamic_reduce_window". +std::optional getDynamicReduceWindowOp( @@ -835,12 +947,69 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo + CustomCallOp op_; +}; + -+// Wraps a custom call in a DynamicReduceWindowAdaptor. ++// Wraps a custom call in a DynamicRngBitGeneratorOpAdaptor. +// Fails if the call_target_name of the custom call doesn't match +// "stablehlo.dynamic_rng_bit_generator". +std::optional getDynamicRngBitGeneratorOp( + CustomCallOp op); + ++// The DynamicTopKOp experiment provides a dynamic version of ++// TopKOp. Once the dynamism RFC is figured out, we expect to have an ++// upstream representation for this notion. ++// ++// Within this experiment, DynamicTopKOp is represented via the ++// `stablehlo.custom_call @stablehlo.dynamic_top_k` custom call. ++// This custom call has the regular operand of TopKOp plus an ++// additional `k` operand that determines the shape of the output. ++// ++// Semantics of DynamicTopKOp are inherited from semantics of Chlo.TopKOp. 
++// ++// #### Inputs ++// ++// | Label | Name | Type | ++// |-------|-----------------|----------------------------------------------| ++// | (I1) | `operand` | tensor of integer or floating-point type | ++// | (I2) | `k` | 0-dimensional tensor of integer or index type| ++// ++// #### Outputs ++// ++// | Name | Type | ++// |----------------|------------------------------------------| ++// | `values` | tensor of integer or floating-point type | ++// | `indices` | tensor of si32 type | ++// ++// #### Constraints ++// ++// * (C1) `shape(values)[:-1] = shape(operand)[:-1]` ++// * (C2) `element_type(values) = element_type(operand)` ++// * (C3) `shape(values)[-1] <= shape(operand)[-1]` ++// * (C4) `shape(indices) = shape(values)` ++class DynamicTopKOpAdaptor { ++ public: ++ DynamicTopKOpAdaptor(CustomCallOp op) : op_(op) {} ++ operator Operation*() { return op_; } ++ Operation* operator->() { return op_; } ++ ++ // These accessors assume that the operation is well-formed (i.e. that it ++ // can pass verification). ++ TypedValue getOperand(); ++ TypedValue getK(); ++ TypedValue getValues(); ++ TypedValue getIndices(); ++ ++ // Verifies the constraints documented above. ++ // Emits errors if errors are detected. ++ LogicalResult verify(); ++ ++ private: ++ CustomCallOp op_; ++}; ++ ++// Wraps a custom call in a DynamicTopKOpAdaptor. ++// Fails if the call_target_name of the custom call doesn't match ++// "stablehlo.dynamic_top_k". ++std::optional getDynamicTopKOp(CustomCallOp op); ++ +} // namespace stablehlo +} // namespace mlir + @@ -4212,7 +4381,7 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st // CHECK-LABEL: func @dynamic_reshape_success func.func @dynamic_reshape_success(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK-NOT: stablehlo.dynamic_reshape -@@ -452,6 +618,44 @@ +@@ -452,6 +618,185 @@ %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x?xf32> return %1 : tensor<1x?xf32> @@ -4254,13 +4423,154 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/st + rng_algorithm = #stablehlo + } : (tensor<2xui64>, tensor<2xi64>) -> (tensor<2xui64>, tensor) + return %1#1 : tensor ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_top_k_success ++func.func @dynamic_top_k_success(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<3xi32>) { ++ // CHECK: chlo.top_k ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @dynamic_top_k_failure_k_mismatch ++func.func @dynamic_top_k_failure_k_mismatch(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<3xi32>) { ++ // CHECK: @stablehlo.dynamic_top_k ++ %k = stablehlo.constant dense<4> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k I1 ++// CHECK-LABEL: func @dynamic_top_k_error_operand_not_float ++func.func @dynamic_top_k_error_operand_not_float(%arg0: tensor<16xcomplex>) -> (tensor<3xcomplex>, tensor<3xi32>) { ++ // expected-error@+2{{expects operand #0 to be a tensor of integer or floating-point type}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : 
(tensor<16xcomplex>, tensor) -> (tensor<3xcomplex>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xcomplex>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k I1 ++// CHECK-LABEL: func @dynamic_top_k_error_operand_unranked ++func.func @dynamic_top_k_error_operand_unranked(%arg0: tensor<*xf32>) -> (tensor<3xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects operand #0 to be a tensor of integer or floating-point type of rank at least 1}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<*xf32>, tensor) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k I1 ++// CHECK-LABEL: func @dynamic_top_k_error_scalar_operand ++func.func @dynamic_top_k_error_scalar_operand(%arg0: tensor) -> (tensor<3xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects operand #0 to be a tensor of integer or floating-point type of rank at least 1}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor, tensor) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k I2 ++// CHECK-LABEL: func @dynamic_top_k_error_k_not_integer ++func.func @dynamic_top_k_error_k_not_integer(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects k (operand #1) to be a 0-dimensional tensor of integer or index type}} ++ %k = stablehlo.constant dense<3.> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k I2 ++// CHECK-LABEL: func @dynamic_top_k_error_k_not_scalar ++func.func @dynamic_top_k_error_k_not_scalar(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects k (operand #1) to be a 0-dimensional tensor of integer or index type}} ++ %k = stablehlo.constant dense<3> : tensor<1xui64> ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor<1xui64>) -> (tensor<3xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k O1 ++// CHECK-LABEL: func @dynamic_top_k_error_values_not_float ++func.func @dynamic_top_k_error_values_not_float(%arg0: tensor<16xf32>) -> (tensor<3xcomplex>, tensor<3xi32>) { ++ // expected-error@+2{{expects values (result #0) to be a tensor of integer or floating-point type}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xcomplex>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xcomplex>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k O2 ++// CHECK-LABEL: func @dynamic_top_k_error_indices_not_i32 ++func.func @dynamic_top_k_error_indices_not_i32(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<3xi64>) { ++ // expected-error@+2{{expects indices (result #1) to be a tensor of si32}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<3xi64>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<3xi64> ++} ++ ++// ----- ++ ++// dynamic_top_k C1 ++// CHECK-LABEL: func @dynamic_top_k_error_values_bad_rank ++func.func @dynamic_top_k_error_values_bad_rank(%arg0: tensor<16xf32>) -> 
(tensor<3x4xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects the values shape to match the operand shape in all but the last dimension}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3x4xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3x4xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k C2 ++// CHECK-LABEL: func @dynamic_top_k_error_values_bad_element_type ++func.func @dynamic_top_k_error_values_bad_element_type(%arg0: tensor<16xf32>) -> (tensor<3xf64>, tensor<3xi32>) { ++ // expected-error@+2{{expects the values element type to be the same as the operand element type}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf64>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<3xf64>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k C3 ++// CHECK-LABEL: func @dynamic_top_k_error_values_last_dim_too_large ++func.func @dynamic_top_k_error_values_last_dim_too_large(%arg0: tensor<16xf32>) -> (tensor<17xf32>, tensor<3xi32>) { ++ // expected-error@+2{{expects the values last dimension to have size at least as large as operand last dimension}} ++ %k = stablehlo.constant dense<17> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<17xf32>, tensor<3xi32>) ++ return %1#0, %1#1 : tensor<17xf32>, tensor<3xi32> ++} ++ ++// ----- ++ ++// dynamic_top_k C4 ++// CHECK-LABEL: func @dynamic_top_k_error_indices_shape_mismatch ++func.func @dynamic_top_k_error_indices_shape_mismatch(%arg0: tensor<16xf32>) -> (tensor<3xf32>, tensor<4xi32>) { ++ // expected-error@+2{{expects the indices shape to match the values shape}} ++ %k = stablehlo.constant dense<3> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor<3xf32>, tensor<4xi32>) ++ return %1#0, %1#1 : tensor<3xf32>, tensor<4xi32> } // ----- diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir --- stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir +++ stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir -@@ -607,12 +607,45 @@ +@@ -607,12 +607,55 @@ // ----- @@ -4303,21 +4613,43 @@ diff --ruN a/stablehlo/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/ + rng_algorithm = #stablehlo + } : (tensor<2xui64>, tensor<2xi64>) -> (tensor, tensor<*xf32>) + func.return %1#0, %1#1 : tensor, tensor<*xf32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: func @refine_dynamic_top_k ++func.func @refine_dynamic_top_k(%arg0: tensor<16xf32>) -> (tensor, tensor) { ++ // CHECK: stablehlo.dynamic_top_k{{.*}} -> (tensor<4xf32>, tensor<4xi32>) ++ %k = stablehlo.constant dense<4> : tensor ++ %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor, tensor) ++ return %1#0, %1#1 : tensor, tensor } // ----- +diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td +--- stablehlo/stablehlo/transforms/Passes.td ++++ stablehlo/stablehlo/transforms/Passes.td +@@ -25,6 +25,7 @@ + For example, if the output_shape operand of DynamicReshapeOp is a constant + value, then the operation can be transformed to ReshapeOp. 
+ }]; ++ let dependentDialects = ["mlir::chlo::ChloDialect"]; + } + + def StablehloLegalizeToVhloPass : Pass<"stablehlo-legalize-to-vhlo", "ModuleOp"> { diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp --- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp -@@ -24,6 +24,7 @@ +@@ -24,6 +24,8 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" ++#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/ExperimentalOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" -@@ -198,6 +199,54 @@ +@@ -198,6 +200,54 @@ } }; @@ -4372,7 +4704,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ struct CanonicalizeDynamicReshapeOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; -@@ -210,6 +259,27 @@ +@@ -210,6 +260,56 @@ if (!op.getType().hasStaticShape()) return rewriter.notifyMatchFailure(op, "expected static result type"); rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); @@ -4397,16 +4729,46 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/ + return rewriter.notifyMatchFailure(op, "expected static output type"); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getRngAlgorithm(), op.getInitialState()); ++ return success(); ++ } ++}; ++ ++struct CanonicalizeDynamicTopKOpPattern ++ : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicTopKOpAdaptor op = *maybeOp; ++ ++ SmallVector k; ++ if (failed(hlo::matchInts(op.getK(), k))) ++ return rewriter.notifyMatchFailure(impl, "expected constant k"); ++ ++ // We rely on many of the properties checked by verification. ++ auto valuesType = op.getValues().getType().cast(); ++ auto valuesLastDimSize = valuesType.getShape()[valuesType.getRank() - 1]; ++ if (hlo::isDynamicDimSize(valuesLastDimSize) || ++ valuesLastDimSize != k[0]) ++ return rewriter.notifyMatchFailure( ++ op, ++ "expected value of k to match the values last dimension size of " ++ "static values type (result #0)"); ++ ++ rewriter.replaceOpWithNewOp( ++ op, op->getResultTypes(), op.getOperand(), k[0]); return success(); } }; -@@ -320,7 +390,9 @@ +@@ -320,7 +420,10 @@ patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); patterns.add( &getContext()); patterns.add(&getContext()); @@ -4421,7 +4783,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/TypeInference.h" #include "stablehlo/transforms/Passes.h" -@@ -844,12 +845,78 @@ +@@ -844,12 +845,97 @@ } }; @@ -4497,16 +4859,36 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl + // The shape of `output_state` (the first result) is determined by the shape + // of `initial_state`, so we ignore it and provide an empty refinement. 
+ return refineReturnTypes(rewriter, op, {{initialStateType}, {outputShape}}); ++ } ++}; ++ ++struct RefineDynamicTopKOpPattern : public OpRewritePattern { ++ using OpRewritePattern::OpRewritePattern; ++ LogicalResult matchAndRewrite(CustomCallOp impl, ++ PatternRewriter& rewriter) const override { ++ auto maybeOp = getDynamicTopKOp(impl); ++ if (!maybeOp || failed(maybeOp->verify())) return failure(); ++ DynamicTopKOpAdaptor op = *maybeOp; ++ ++ auto operandType = op.getOperand().getType().cast(); ++ SmallVector outputShape(operandType.getShape()); ++ SmallVector k; ++ if (failed(hlo::matchInts(op.getK(), k))) ++ return rewriter.notifyMatchFailure(op, "expected constant k"); ++ ++ outputShape[operandType.getRank() - 1] = k[0]; ++ return refineReturnTypes(rewriter, op, {{outputShape}, {outputShape}}); } }; -@@ -1181,7 +1248,9 @@ +@@ -1181,7 +1267,10 @@ patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add(&getContext()); ++ patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); From f1759fd02e88f9f84d95345462b67a758c442e4e Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Fri, 28 Jul 2023 05:47:28 -0700 Subject: [PATCH 309/410] Fix typo in test name. PiperOrigin-RevId: 551828157 --- tensorflow/lite/mutable_op_resolver_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc index 6f409a47533607..8622579a3c8aa3 100644 --- a/tensorflow/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -64,7 +64,7 @@ TfLiteRegistration* GetDummy2Registration() { return ®istration; } -TEST(MutableOpResolverTest, FinOp) { +TEST(MutableOpResolverTest, FindOp) { MutableOpResolver resolver; resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration()); From e3902ad17b5f2a7308320be0d3f5cf543afca7c5 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Fri, 28 Jul 2023 06:17:32 -0700 Subject: [PATCH 310/410] IF STATIC: sharing node inputs and outputs with subgraph inputs and outputs IF DYNAMIC: sharing node inputs with subgraph inputs PiperOrigin-RevId: 551833537 --- tensorflow/lite/kernels/BUILD | 18 + tensorflow/lite/kernels/control_flow_common.h | 163 +++++ tensorflow/lite/kernels/if.cc | 248 ++++--- tensorflow/lite/kernels/if_test.cc | 682 ++++++++++++++++++ tensorflow/lite/kernels/subgraph_test_util.cc | 215 ++++++ tensorflow/lite/kernels/subgraph_test_util.h | 19 + tensorflow/lite/kernels/while.cc | 152 +--- tensorflow/lite/kernels/while_test.cc | 4 +- .../lite/profiling/profile_summarizer_test.cc | 8 +- 9 files changed, 1270 insertions(+), 239 deletions(-) create mode 100644 tensorflow/lite/kernels/control_flow_common.h diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 3b4628935c927e..03902b8da48151 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -582,6 +582,20 @@ cc_library( ], ) +cc_library( + name = "control_flow_common", + srcs = [], + hdrs = ["control_flow_common.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + deps = [ + ":kernel_util", + "//tensorflow/lite:kernel_api", + "//tensorflow/lite/core:subgraph", + "//tensorflow/lite/core/c:common", + ], +) + # See also VARIABLE_KERNEL_SRCS below. 
BUILTIN_KERNEL_SRCS = [ "activations.cc", @@ -714,6 +728,7 @@ BUILTIN_KERNEL_DEPS = [ ":lstm_shared", ":op_macros", ":padding", + ":control_flow_common", "//third_party/eigen3", "@flatbuffers", "//tensorflow/lite:framework_stable", @@ -2612,6 +2627,9 @@ cc_test( ":test_main", "//tensorflow/lite:framework_stable", "//tensorflow/lite/core:framework_stable", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/profiling:memory_info", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/kernels/control_flow_common.h b/tensorflow/lite/kernels/control_flow_common.h new file mode 100644 index 00000000000000..5bcc25850c3022 --- /dev/null +++ b/tensorflow/lite/kernels/control_flow_common.h @@ -0,0 +1,163 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_CONTROL_FLOW_COMMON_H_ +#define TENSORFLOW_LITE_KERNELS_CONTROL_FLOW_COMMON_H_ + +#include + +#include "tensorflow/lite/context_util.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` +// to `dst_tensor_indices` in `dst_subgraph`. +// +// When `resize_subgraph_inputs` is true, the function calls subgraphs's +// `ResizeInputTensor` function, and it may trigger the memory planner to +// reallocate memory. +// When `resize_subgraph_inputs` is false, it implies `context` belongs to +// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens +// when resizing `While` op's outputs. +template +TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context, + Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, + Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices, + bool resize_subgraph_inputs) { + TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), + dst_tensor_indices.size()); + for (int i = 0; i < src_tensor_indices.size(); ++i) { + // Skip copying unused destination tensors. + if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; + + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + if (resize_subgraph_inputs) { + std::vector dims(src_tensor->dims->data, + src_tensor->dims->data + src_tensor->dims->size); + dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); + } else { + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, dst_tensor, + TfLiteIntArrayCopy(src_tensor->dims))); + } + dst_tensor->type = src_tensor->type; + } + return kTfLiteOk; +} + +// Copy the tensors data from tensors `src_tensor_indices` in `src_subgraph` +// to `dst_tensor_indices` in `dst_subgraph`. 
+template +TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, + Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices) { + TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), + dst_tensor_indices.size()); + for (int i = 0; i < src_tensor_indices.size(); ++i) { + // Skip copying unused destination tensors. + if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; + + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + if (IsDynamicTensor(dst_tensor)) { + TfLiteTensorRealloc(src_tensor->bytes, dst_tensor); + } + TF_LITE_ENSURE_OK(context, TfLiteTensorCopy(src_tensor, dst_tensor)); + } + return kTfLiteOk; +} + +// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` +// to `dst_tensor_indices` in `dst_subgraph` and copy data deeply. +template +TfLiteStatus DeepCopyTensorsShapeTypeData( + TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices, bool body_has_dynamic_output_tensors) { + if (body_has_dynamic_output_tensors) { + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + bool resize_subgraph_inputs = (dst_subgraph != this_subgraph); + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType( + context, src_subgraph, src_tensor_indices, dst_subgraph, + dst_tensor_indices, resize_subgraph_inputs)); + if (resize_subgraph_inputs) { + TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); + } + } + TF_LITE_ENSURE_OK(context, + CopyTensorsData(context, src_subgraph, src_tensor_indices, + dst_subgraph, dst_tensor_indices)); + return kTfLiteOk; +} + +template +TfLiteStatus DeepOrShallowCopyTensorsShapeTypeData( + TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph, + const SrcVector& src_tensor_indices, Subgraph* dst_subgraph, + const DstVector& dst_tensor_indices) { + // Resize the destination subgraph inputs. + for (int i = 0; i < src_tensor_indices.size(); ++i) { + // Skip copying unused destination tensors. + if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; + if (src_tensor_indices[i] == kTfLiteOptionalTensor) continue; + + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + std::vector dims(src_tensor->dims->data, + src_tensor->dims->data + src_tensor->dims->size); + dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); + dst_tensor->type = src_tensor->type; + if (!IsResourceOrVariant(src_tensor)) { + dst_tensor->bytes = 0; // Don't allocate memory with AllocateTensors(). + dst_tensor->data.raw = nullptr; + } + } + TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); + // Deep or shallow copy the data from src subgraph to dst. + for (int i = 0; i < src_tensor_indices.size(); ++i) { + // Skip copying unused destination tensors. 
+ if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; + if (src_tensor_indices[i] == kTfLiteOptionalTensor) continue; + + const TfLiteTensor* src_tensor = + src_subgraph->tensor(src_tensor_indices[i]); + TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); + if (IsResourceOrVariant(src_tensor)) { + TfLiteTensorRealloc(src_tensor->bytes, dst_tensor); + TF_LITE_ENSURE_OK(context, TfLiteTensorCopy(src_tensor, dst_tensor)); + } else { + dst_tensor->bytes = src_tensor->bytes; + dst_tensor->data.raw = src_tensor->data.raw; + } + } + return kTfLiteOk; +} +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_KERNELS_CONTROL_FLOW_COMMON_H_ diff --git a/tensorflow/lite/kernels/if.cc b/tensorflow/lite/kernels/if.cc index ad724f45876e0c..dce144d3a892d1 100644 --- a/tensorflow/lite/kernels/if.cc +++ b/tensorflow/lite/kernels/if.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/control_flow_common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -33,6 +35,7 @@ namespace if_kernel { struct OpData { int then_subgraph_index; int else_subgraph_index; + bool subgraph_has_dynamic_output_tensors; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -40,6 +43,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { const auto* params = reinterpret_cast(buffer); op_data->then_subgraph_index = params->then_subgraph_index; op_data->else_subgraph_index = params->else_subgraph_index; + op_data->subgraph_has_dynamic_output_tensors = false; return op_data; } @@ -48,7 +52,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - const OpData* op_data = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); TF_LITE_ENSURE(context, node->inputs->size > 0); @@ -80,30 +84,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size()); } - bool has_dynamic_output_tensors = false; + // Remove unused inputs of both subgraphs to skip copying unnecessary + // inputs. + then_subgraph->RemoveUnusedInputs(); + else_subgraph->RemoveUnusedInputs(); + + const int* const start = node->inputs->data + 1; + std::vector node_inputs(start, start + num_inputs); + // Prepare and check the subgraphs. + for (auto* subgraph : {then_subgraph, else_subgraph}) { + TF_LITE_ENSURE_OK( + context, CopyTensorsShapeAndType(context, this_subgraph, node_inputs, + subgraph, subgraph->inputs(), true)); + } + for (auto* subgraph : {then_subgraph, else_subgraph}) { for (int i = 0; i < num_inputs; ++i) { - // The first input of the node is the condition. The indices of the inputs - // passed to the subgraphs are offset by 1. 
- const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input)); - std::vector dims(input->dims->data, - input->dims->data + input->dims->size); - TF_LITE_ENSURE_OK(context, subgraph->ResizeInputTensor(i, dims)); - TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]); - if (IsDynamicTensor(input)) { - SetTensorToDynamic(subgraph_input); + int input_idx = subgraph->inputs()[i]; + if (input_idx == kTfLiteOptionalTensor) continue; + TfLiteTensor* subgraph_input = subgraph->tensor(input_idx); + if (!IsResourceOrVariant(subgraph_input)) { + // Set the allocation type to custom to prevent memory allocation. + subgraph_input->allocation_type = kTfLiteCustom; } - TF_LITE_ENSURE_TYPES_EQ(context, input->type, subgraph_input->type); } - // Note: The `Prepare` function is responsible to run `AllocateTensors` on - // both subgraphs. It's intentionally not to break out of the loop when - // finding a dynamic output tensor. TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors()); - has_dynamic_output_tensors |= subgraph->HasDynamicTensors(); + op_data->subgraph_has_dynamic_output_tensors |= + subgraph->HasDynamicTensors(); } - if (!has_dynamic_output_tensors) { + if (!op_data->subgraph_has_dynamic_output_tensors) { for (int i = 0; i < num_outputs; ++i) { TfLiteTensor* then_output = then_subgraph->tensor(then_subgraph->outputs()[i]); @@ -112,20 +122,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // If the 2 subgraphs have static but different output shapes, the output // tensors of the IF op have dynamic sizes. if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) { - has_dynamic_output_tensors = true; + op_data->subgraph_has_dynamic_output_tensors = true; break; } } } for (int i = 0; i < num_outputs; ++i) { + if (node->outputs->data[i] == kTfLiteOptionalTensor) continue; TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); - if (has_dynamic_output_tensors) { + if (op_data->subgraph_has_dynamic_output_tensors) { SetTensorToDynamic(output); } else { - // When there's no dynamic output tensors, the 2 subgraph has exactly - // the same static sized outputs. TfLiteTensor* then_output = then_subgraph->tensor(then_subgraph->outputs()[i]); TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims); @@ -133,99 +142,154 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, output, output_size)); } } - return kTfLiteOk; } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const OpData* op_data = reinterpret_cast(node->user_data); - - const TfLiteTensor* cond; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond)); - bool cond_value = cond->data.b[0]; +// Returns the subgraph input tensor index if the given output is also an input. +int output_is_input(int output_idx, const std::vector& subgraph_inputs) { + auto e = + std::find(subgraph_inputs.begin(), subgraph_inputs.end(), output_idx); + return (e != subgraph_inputs.end()) ? (e - subgraph_inputs.begin()) : -1; +} +// Evaluate IF op when subgraphs have dynamic outputs. +TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node, + Subgraph* active_branch_subgraph) { Subgraph* this_subgraph = reinterpret_cast(context->impl_); - auto* subgraphs = this_subgraph->GetSubgraphs(); - - // Currently we copy the input / output between the subgraphs. This isn't - // optimized yet. - // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs. 
- int active_branch_subgraph_index = - cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index; - Subgraph& active_branch_subgraph = - *(*subgraphs)[active_branch_subgraph_index]; - - // We release memory of the subgraph at the end of evaluation to save memory. - // So it's required to call AllocateTensors() for the second run. - TF_LITE_ENSURE_OK(context, active_branch_subgraph.AllocateTensors()); - - for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) { - const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i + 1, &input)); - TfLiteTensor* subgraph_input = - active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]); - - if (IsDynamicTensor(subgraph_input)) { - TfLiteTensorRealloc(input->bytes, subgraph_input); - } - TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes); - TfLiteTensorCopy(input, subgraph_input); + TF_LITE_ENSURE_OK(context, active_branch_subgraph->AllocateTensors()); + const int num_inputs = node->inputs->size - 1; + const int num_outputs = node->outputs->size; + const int* const start = node->inputs->data + 1; + std::vector node_inputs(start, start + num_inputs); + // node->inputs -> subgraph->inputs + TF_LITE_ENSURE_OK( + context, DeepOrShallowCopyTensorsShapeTypeData( + context, node, this_subgraph, node_inputs, + active_branch_subgraph, active_branch_subgraph->inputs())); + + // Invoke active_branch_subgraph subgraph + TF_LITE_ENSURE_OK(context, active_branch_subgraph->Invoke()); + for (int tensor_index : active_branch_subgraph->outputs()) { + active_branch_subgraph->EnsureTensorDataIsReadable(tensor_index); } - TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke()); - - for (int tensor_index : active_branch_subgraph.outputs()) { - active_branch_subgraph.EnsureTensorDataIsReadable(tensor_index); - } + // subgraph->outputs -> node->outputs + TF_LITE_ENSURE_OK(context, + DeepCopyTensorsShapeTypeData( + context, node, active_branch_subgraph, + active_branch_subgraph->outputs(), this_subgraph, + TfLiteIntArrayView(node->outputs), true)); - bool has_dynamic_output_tensors = false; - for (int i = 0; i < node->outputs->size; ++i) { - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); - if (IsDynamicTensor(output)) { - has_dynamic_output_tensors = true; - break; + for (int i = 0; i < num_outputs; ++i) { + const int input_pos = output_is_input(active_branch_subgraph->outputs()[i], + active_branch_subgraph->inputs()); + if (input_pos != -1) { + TfLiteTensor* this_input = + this_subgraph->tensor(node->inputs->data[input_pos + 1]); + TfLiteTensor* this_output = this_subgraph->tensor(node->outputs->data[i]); + TfLiteTensorCopy(this_input, this_output); } } + return kTfLiteOk; +} - if (has_dynamic_output_tensors) { - for (int i = 0; i < node->outputs->size; ++i) { - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); - TfLiteTensor* subgraph_output = - active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); - TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims); - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output, output_size)); +// Evaluate IF op when subgraphs has static outputs. 
+TfLiteStatus Eval_static(TfLiteContext* context, TfLiteNode* node, + Subgraph* active_branch_subgraph) { + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + + const int num_inputs = node->inputs->size - 1; + const int num_outputs = node->outputs->size; + const int* const start = node->inputs->data + 1; + std::vector node_inputs(start, start + num_inputs); + for (int i = 0; i < num_outputs; ++i) { + int output_idx = active_branch_subgraph->outputs()[i]; + if (output_idx == kTfLiteOptionalTensor) continue; + TfLiteTensor* subgraph_output = active_branch_subgraph->tensor(output_idx); + if (!IsResourceOrVariant(subgraph_output) && + !IsConstantTensor(subgraph_output)) { + subgraph_output->allocation_type = kTfLiteCustom; } } - - for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) { - const TfLiteTensor* subgraph_output = - active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); - - if (IsDynamicTensor(output)) { - TfLiteTensorRealloc(subgraph_output->bytes, output); + // node->inputs -> subgraph->inputs + TF_LITE_ENSURE_OK( + context, DeepOrShallowCopyTensorsShapeTypeData( + context, node, this_subgraph, node_inputs, + active_branch_subgraph, active_branch_subgraph->inputs())); + + TF_LITE_ENSURE_OK( + context, + CopyTensorsShapeAndType(context, active_branch_subgraph, + active_branch_subgraph->outputs(), this_subgraph, + TfLiteIntArrayView(node->outputs), false)); + for (int i = 0; i < num_outputs; ++i) { + TfLiteTensor* this_output = this_subgraph->tensor(node->outputs->data[i]); + TfLiteTensor* subgraph_output = + active_branch_subgraph->tensor(active_branch_subgraph->outputs()[i]); + if (active_branch_subgraph->outputs()[i] == kTfLiteOptionalTensor) { + TfLiteTensor* this_input = + this_subgraph->tensor(node->inputs->data[i + 1]); + TfLiteTensorResizeMaybeCopy(this_input->bytes, this_output, false); + TfLiteTensorCopy(this_input, this_output); + } else { + const int input_pos = + output_is_input(active_branch_subgraph->outputs()[i], + active_branch_subgraph->inputs()); + if (input_pos != -1) { + TfLiteTensor* this_input = + this_subgraph->tensor(node->inputs->data[input_pos + 1]); + TfLiteTensorResizeMaybeCopy(this_input->bytes, this_output, false); + TfLiteTensorCopy(this_input, this_output); + } else if (IsConstantTensor(subgraph_output)) { + TfLiteTensorCopy(subgraph_output, this_output); + } else { + subgraph_output->data = this_output->data; + } } + } - TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes); - TfLiteTensorCopy(subgraph_output, output); + // Invoke subgraph + TF_LITE_ENSURE_OK(context, active_branch_subgraph->Invoke()); + for (int tensor_index : active_branch_subgraph->outputs()) { + active_branch_subgraph->EnsureTensorDataIsReadable(tensor_index); } - // Release memory of subgraphs to save the memory. Though it impacts latency, - // actual impacts looks very little, so no additional option is introduced for - // the feature until we find a different case. 
+ return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast(node->user_data); + Subgraph* this_subgraph = reinterpret_cast(context->impl_); + auto* subgraphs = this_subgraph->GetSubgraphs(); Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get(); Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get(); - TF_LITE_ENSURE_OK(context, then_subgraph->ReleaseMemory()); - TF_LITE_ENSURE_OK(context, else_subgraph->ReleaseMemory()); + + const TfLiteTensor* cond; + TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &cond)); + bool cond_value = cond->data.b[0]; + + Subgraph* active_branch_subgraph; + if (cond_value) { + active_branch_subgraph = then_subgraph; + } else { + active_branch_subgraph = else_subgraph; + } + + if (op_data->subgraph_has_dynamic_output_tensors) { + TF_LITE_ENSURE_OK(context, + Eval_dynamic(context, node, active_branch_subgraph)); + } else { + TF_LITE_ENSURE_OK(context, + Eval_static(context, node, active_branch_subgraph)); + } + + if (!this_subgraph->ShouldPreserveAllTensors()) { + TF_LITE_ENSURE_OK(context, active_branch_subgraph->ReleaseMemory()); + } return kTfLiteOk; } - } // namespace if_kernel TfLiteRegistration* Register_IF() { diff --git a/tensorflow/lite/kernels/if_test.cc b/tensorflow/lite/kernels/if_test.cc index 733ccbd9d4cb21..5fd734bba86b4d 100644 --- a/tensorflow/lite/kernels/if_test.cc +++ b/tensorflow/lite/kernels/if_test.cc @@ -20,14 +20,19 @@ limitations under the License. #include #include "tensorflow/lite/core/interpreter.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/subgraph_test_util.h" namespace tflite { using subgraph_test_util::CheckIntTensor; +using subgraph_test_util::CheckScalarStringTensor; +using subgraph_test_util::CheckStringTensor; using subgraph_test_util::ControlFlowOpTest; using subgraph_test_util::FillIntTensor; +using subgraph_test_util::FillScalarStringTensor; namespace { @@ -155,5 +160,682 @@ TEST_F(DynamicSubgraphIfTest, TestIfFalse) { CheckIntTensor(output, {5}, {0, 5, 7, 0, 0}); } +class IfTest : public ControlFlowOpTest {}; + +TEST_F(IfTest, TestWithXNNPACK) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildXNNPACKSubgraph(interpreter_->subgraph(1)); + builder_->BuildXNNPACKSubgraph(interpreter_->subgraph(2)); + builder_->BuildFloatIfSubgraph(&interpreter_->primary_subgraph(), 3); + + const auto opt = TfLiteXNNPackDelegateOptionsDefault(); + TfLiteDelegate* xnnpack_delegate = TfLiteXNNPackDelegateCreate(&opt); + interpreter_->primary_subgraph().MarkAsDelegationSkippable(); + interpreter_->subgraph(1)->MarkAsDelegationSkippable(); + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(xnnpack_delegate), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + float* input0 = + GetTensorData(interpreter_->tensor(interpreter_->inputs()[1])); + input0[0] = 1; + float* input1 = + GetTensorData(interpreter_->tensor(interpreter_->inputs()[2])); + input1[0] = 1; + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + 
TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + float* output0_data = GetTensorData(output0); + ASSERT_EQ(output0_data[0], 4); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + float* output1_data = GetTensorData(output1); + ASSERT_EQ(output1_data[0], 4); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteXNNPackDelegateDelete(xnnpack_delegate); +} + +TEST_F(IfTest, TestInputIsOutput) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(1)); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 4); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {2}); + + interpreter_->typed_input_tensor(0)[0] = false; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + CheckIntTensor(output0, {1}, {2}); + CheckIntTensor(output1, {1}, {2}); +} + +TEST_F(IfTest, TestInputIsOutputButDifferent) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildInputIsDifferentOutputSubgraph(interpreter_->subgraph(1)); + builder_->BuildInputIsDifferentOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestFlexOutput) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildFlexOutputSubgraph(interpreter_->subgraph(1)); + 
builder_->BuildFlexOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {2}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2, 3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {2}, {3, 4}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestCounterOnly) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildCounterOnlySubgraph(interpreter_->subgraph(1)); + builder_->BuildCounterOnlySubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 2); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestAllCases) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildAllInplaceScenariosSubgraph(interpreter_->subgraph(1)); + builder_->BuildAllInplaceScenariosSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 6); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[4], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[5], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[4]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[5]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {3}); + 
TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {3}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[2]); + CheckIntTensor(output2, {2}, {2, 2}); + TfLiteTensor* output3 = interpreter_->tensor(interpreter_->outputs()[3]); + CheckIntTensor(output3, {2}, {3, 3}); + TfLiteTensor* output4 = interpreter_->tensor(interpreter_->outputs()[4]); + CheckIntTensor(output4, {1}, {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestStaticUnconsumedOutputs) { + for (bool dynamic_tensors : {true, false}) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(1)); + builder_->BuildInputIsOutputSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraphWithUnconsumedOutput( + &interpreter_->primary_subgraph(), 4); + + InterpreterOptions options; + if (dynamic_tensors) { + options.OptimizeMemoryForLargeTensors(1); + interpreter_->ApplyOptions(&options); + } + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {4}); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {2}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {2, 2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + CheckIntTensor(output1, {2}, {4, 4}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + } +} + +// Test a body subgraph which triggers the reallocation of an inplace output +// tensor whose corresponding input has not been consumed yet. This tests that +// the input pointer has be updated. 
+TEST_F(IfTest, TestDynamicOpTriggersAllocationOfUnsedInput) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildDynamicOpTriggersAllocationOfUnsedInputSubgraph( + interpreter_->subgraph(1)); + builder_->BuildDynamicOpTriggersAllocationOfUnsedInputSubgraph( + interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 4); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {3}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {2}, {4, 4}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[2]); + CheckIntTensor(output2, {2}, {2, 2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestStaticInPlace) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildDeepBodySubgraph(interpreter_->subgraph(1)); + builder_->BuildDeepBodySubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {0}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {1}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {1}, {3}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestStaticInPlaceLarge) { + int size = 10000; + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildLargeBodySubgraph(interpreter_->subgraph(1)); + builder_->BuildLargeBodySubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {size}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + 
FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), + std::vector(size, 1)); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {}, {10000}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {size}, std::vector(size, 6)); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +// The test builds a model that produces the i-th number of +// triangular number sequence. +TEST_F(IfTest, TestTriangularNumberSequence) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(1)); + builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}), + kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + + // Check If BODY inputs are static tensors. + auto body_subgraph = interpreter_->subgraph(2); + TfLiteTensor* subgraph_input2 = + body_subgraph->tensor(body_subgraph->inputs()[1]); + EXPECT_EQ(subgraph_input2->allocation_type, kTfLiteCustom); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {2}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {1}, {3}); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestTriangularNumberSequenceWithShallowCopy) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(1)); + builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + // Use 4MB inputs to test shallow copy. + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1000000}); + // Apply DynamicAllocationForLargeTensors option to enable shallow copy. + InterpreterOptions options; + options.OptimizeMemoryForLargeTensors(1000000); + ASSERT_EQ(interpreter_->ApplyOptions(&options), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + const std::vector input_vector(1000000, 1); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), input_vector); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + + auto body_subgraph = interpreter_->subgraph(2); + // If BODY inputs are dynamic tensors with shallow copy. 
+ TfLiteTensor* subgraph_input2 = + body_subgraph->tensor(body_subgraph->inputs()[1]); + ASSERT_EQ(subgraph_input2->allocation_type, kTfLiteCustom); + + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {2}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + const std::vector expected2(1000000, 3); + CheckIntTensor(output2, {1000000}, expected2); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestPadLoop) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(1), {1, 2}); + builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(2), {1, 2}); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + interpreter_->typed_input_tensor(0)[0] = false; + + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {5, 7}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {2}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output2, {5}, {0, 5, 7, 0, 0}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestDynamicBodyWithSharingEarlyExit) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildDynamicIncreasingSizeSubgraph(interpreter_->subgraph(1)); + builder_->BuildDynamicIncreasingSizeSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 5); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {3}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {10000}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2, 3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {3}, {2, 3, 4}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestDynamicBodyWithSharing) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + + builder_->BuildDynamicIncreasingSizeSubgraph(interpreter_->subgraph(1)); + builder_->BuildDynamicIncreasingSizeSubgraph(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 5); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], 
{1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {3}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1000000}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[4], {1000000}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2, 3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {2}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {3}, {2, 3, 4}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[2]); + EXPECT_EQ(output2->dims->data[0], 1000000); + TfLiteTensor* output3 = interpreter_->tensor(interpreter_->outputs()[3]); + EXPECT_EQ(output3->dims->data[0], 1000000); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestDynamicBodyWithSharingAndAliases) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildDynamicBodySubgraphWithAliases(interpreter_->subgraph(1)); + builder_->BuildDynamicBodySubgraphWithAliases(interpreter_->subgraph(2)); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 6); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[4], {1}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {0}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[4]), {3}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[5]), {4}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {1}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {1}, {11}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[2]); + CheckIntTensor(output2, {1}, {12}); + TfLiteTensor* output3 = interpreter_->tensor(interpreter_->outputs()[4]); + CheckIntTensor(output3, {1}, {13}); + TfLiteTensor* output4 = interpreter_->tensor(interpreter_->outputs()[4]); + CheckIntTensor(output4, {1}, {13}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + 
+TEST_F(IfTest, TestOutputNotConsumed) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildOutputNotConsumedSubgraph(*interpreter_->subgraph(1)); + builder_->BuildOutputNotConsumedSubgraph(*interpreter_->subgraph(2)); + builder_->BuildOutputNotConsumedIfSubgraph(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = true; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {3}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestPadLoopWithSharing) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildLargePadSubgraph(interpreter_->subgraph(1), {1, 2}); + builder_->BuildLargePadSubgraph(interpreter_->subgraph(2), {1, 2}); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 4); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {2}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {2}); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {3, 4}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output0 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output0, {1}, {3}); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[1]); + CheckIntTensor(output1, {2}, {5, 6}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[2]); + CheckIntTensor(output2, {5}, {0, 5, 6, 0, 0}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestPadLoopWithShallowCopy) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(1), {1, 2}); + builder_->BuildPadLoopBodySubgraph(interpreter_->subgraph(2), {1, 2}); + builder_->BuildMultiInputIfSubgraph(&interpreter_->primary_subgraph(), 3); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); + // Use 4MB inputs to test shallow copy. 
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1000000}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = false; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1}); + std::vector input_vector(1000000, 0); + input_vector[0] = 5; + input_vector[1] = 7; + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), input_vector); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); + CheckIntTensor(output1, {1}, {2}); + TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); + std::vector output_vector(1000003, 0); + output_vector[1] = 5; + output_vector[2] = 7; + CheckIntTensor(output2, {1000003}, output_vector); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + +TEST_F(IfTest, TestIfLoopWithDynamicTensor) { + interpreter_ = std::make_unique(); + AddSubgraphs(2); + builder_->BuildBodySubgraphWithDynamicTensor(interpreter_->subgraph(1)); + builder_->BuildBodySubgraphWithDynamicTensor(interpreter_->subgraph(2)); + builder_->BuildIfSubgraphWithDynamicTensor(&interpreter_->primary_subgraph()); + + interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {}); + interpreter_->ResizeInputTensor(interpreter_->inputs()[3], {1}); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + interpreter_->typed_input_tensor(0)[0] = false; + FillScalarStringTensor(interpreter_->tensor(interpreter_->inputs()[1]), "A"); + FillScalarStringTensor(interpreter_->tensor(interpreter_->inputs()[2]), "A"); + FillIntTensor(interpreter_->tensor(interpreter_->inputs()[3]), {1}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* string_output1 = + interpreter_->tensor(interpreter_->outputs()[0]); + CheckScalarStringTensor(string_output1, "A"); + TfLiteTensor* string_output2 = + interpreter_->tensor(interpreter_->outputs()[1]); + CheckStringTensor(string_output2, {2}, {"A", "A"}); + TfLiteTensor* integer_output = + interpreter_->tensor(interpreter_->outputs()[2]); + CheckIntTensor(integer_output, {1}, {2}); + + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc index 809c185621e015..f2bf6397ca43cf 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.cc +++ b/tensorflow/lite/kernels/subgraph_test_util.cc @@ -1265,6 +1265,54 @@ void SubgraphBuilder::BuildPadLoopBodySubgraph( &node_index); } +void SubgraphBuilder::BuildOutputNotConsumedIfSubgraph(Subgraph* subgraph) { + enum { + kInput0, + kInput1, + kInput2, + kInput3, + kOutput0, + kOutput1, + kOutput2, + kTensorCount + }; + + int num_inputs = 4; + int num_outputs = 3; + int first_new_tensor_index; + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + std::vector input_tensors(num_inputs); + std::vector output_tensors(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + input_tensors[i] = i; + } + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i] = i + num_inputs; + } + ASSERT_EQ(subgraph->SetInputs(input_tensors), kTfLiteOk); + 
ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); + SetupTensor(subgraph, input_tensors[0], kTfLiteBool); + for (int i = 1; i < num_inputs; ++i) { + SetupTensor(subgraph, input_tensors[i], kTfLiteInt32); + } + for (int i = 0; i < num_outputs; ++i) { + SetupTensor(subgraph, output_tensors[i], kTfLiteInt32); + } + + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteIfParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; + auto* if_reg = ops::builtin::Register_IF(); + if_reg->builtin_code = kTfLiteBuiltinIf; + + int node_index; + subgraph->AddNodeWithParameters(input_tensors, output_tensors, {}, nullptr, 0, + params, if_reg, &node_index); +} + void SubgraphBuilder::BuildOutputNotConsumedWhileSubgraph(Subgraph* subgraph) { enum { kInput0, @@ -1303,6 +1351,44 @@ void SubgraphBuilder::BuildOutputNotConsumedWhileSubgraph(Subgraph* subgraph) { while_reg, &node_index); } +void SubgraphBuilder::BuildFloatIfSubgraph(Subgraph* subgraph, int num_inputs) { + int num_outputs = num_inputs - 1; + int first_new_tensor_index; + ASSERT_EQ( + subgraph->AddTensors(num_inputs + num_outputs, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + std::vector input_tensors(num_inputs); + std::vector output_tensors(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + input_tensors[i] = i; + } + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i] = i + num_inputs; + } + ASSERT_EQ(subgraph->SetInputs(input_tensors), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); + + SetupTensor(subgraph, input_tensors[0], kTfLiteBool); + for (int i = 1; i < num_inputs; ++i) { + SetupTensor(subgraph, input_tensors[i], kTfLiteFloat32); + } + for (int i = 0; i < num_outputs; ++i) { + SetupTensor(subgraph, output_tensors[i], kTfLiteFloat32); + } + + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteWhileParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; + auto* if_reg = ops::builtin::Register_IF(); + if_reg->builtin_code = kTfLiteBuiltinIf; + + int node_index; + subgraph->AddNodeWithParameters(input_tensors, output_tensors, {}, nullptr, 0, + params, if_reg, &node_index); +} + void SubgraphBuilder::BuildFloatWhileSubgraph(Subgraph* subgraph, int num_inputs) { // kInput1(0) --> +-------+ --> kOutput1(2) @@ -1339,6 +1425,46 @@ void SubgraphBuilder::BuildFloatWhileSubgraph(Subgraph* subgraph, params, while_reg, &node_index); } +void SubgraphBuilder::BuildMultiInputIfSubgraphWithUnconsumedOutput( + Subgraph* subgraph, int num_inputs) { + int num_outputs = num_inputs - 1; + int first_new_tensor_index; + ASSERT_EQ( + subgraph->AddTensors(num_inputs + num_outputs, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + std::vector input_tensors(num_inputs); + std::vector output_tensors(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + input_tensors[i] = i; + } + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i] = i + num_inputs; + } + SetupTensor(subgraph, input_tensors[0], kTfLiteBool); + for (int i = 1; i < num_inputs; ++i) { + SetupTensor(subgraph, input_tensors[i], kTfLiteInt32); + } + for (int i = 0; i < num_outputs; ++i) { + SetupTensor(subgraph, output_tensors[i], kTfLiteInt32); + } + + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteIfParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; + auto* if_reg = ops::builtin::Register_IF(); + if_reg->builtin_code = kTfLiteBuiltinIf; + + int node_index; + 
subgraph->AddNodeWithParameters(input_tensors, output_tensors, {}, nullptr, 0, + params, if_reg, &node_index); + + output_tensors.pop_back(); + ASSERT_EQ(subgraph->SetInputs(input_tensors), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); +} + void SubgraphBuilder::BuildMultiInputWhileSubgraphWithUnconsumedOutput( Subgraph* subgraph, int num_inputs) { // kInput1(0) --> +-------+ --> kOutput1(2) @@ -1376,6 +1502,45 @@ void SubgraphBuilder::BuildMultiInputWhileSubgraphWithUnconsumedOutput( ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); } +void SubgraphBuilder::BuildMultiInputIfSubgraph(Subgraph* subgraph, + int num_inputs) { + int num_outputs = num_inputs - 1; + int first_new_tensor_index; + ASSERT_EQ( + subgraph->AddTensors(num_inputs + num_outputs, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + std::vector input_tensors(num_inputs); + std::vector output_tensors(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + input_tensors[i] = i; + } + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i] = i + num_inputs; + } + ASSERT_EQ(subgraph->SetInputs(input_tensors), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); + + SetupTensor(subgraph, input_tensors[0], kTfLiteBool); + for (int i = 1; i < num_inputs; ++i) { + SetupTensor(subgraph, input_tensors[i], kTfLiteInt32); + } + for (int i = 0; i < num_outputs; ++i) { + SetupTensor(subgraph, output_tensors[i], kTfLiteInt32); + } + + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteIfParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; + auto* if_reg = ops::builtin::Register_IF(); + if_reg->builtin_code = kTfLiteBuiltinWhile; + + int node_index; + subgraph->AddNodeWithParameters(input_tensors, output_tensors, {}, nullptr, 0, + params, if_reg, &node_index); +} + void SubgraphBuilder::BuildMultiInputWhileSubgraph(Subgraph* subgraph, int num_inputs) { // kInput1(0) --> +-------+ --> kOutput1(2) @@ -1632,6 +1797,56 @@ void SubgraphBuilder::BuildBodySubgraphWithDynamicTensor(Subgraph* subgraph) { fill_reg, &node_index); } +void SubgraphBuilder::BuildIfSubgraphWithDynamicTensor(Subgraph* subgraph) { + enum { + kBoolInput0, + kStringInput1, + kStringInput2, + kIntegerInput, + kStringOutput1, + kStringOutput2, + kIntegerOutput, + kTensorCount + }; + + int num_inputs = 4; + int num_outputs = num_inputs - 1; + // Create a if op with 2 string tensor and 1 integer tensor. 
+ int first_new_tensor_index; + std::vector input_tensors(num_inputs); + std::vector output_tensors(num_outputs); + for (int i = 0; i < num_inputs; ++i) { + input_tensors[i] = i; + } + for (int i = 0; i < num_outputs; ++i) { + output_tensors[i] = i + num_inputs; + } + ASSERT_EQ(subgraph->AddTensors(kTensorCount, &first_new_tensor_index), + kTfLiteOk); + ASSERT_EQ(first_new_tensor_index, 0); + ASSERT_EQ(subgraph->SetInputs(input_tensors), kTfLiteOk); + ASSERT_EQ(subgraph->SetOutputs(output_tensors), kTfLiteOk); + + SetupTensor(subgraph, kBoolInput0, kTfLiteBool); + SetupTensor(subgraph, kStringInput1, kTfLiteString); + SetupTensor(subgraph, kStringInput2, kTfLiteString); + SetupTensor(subgraph, kIntegerInput, kTfLiteInt32); + SetupTensor(subgraph, kStringOutput1, kTfLiteString); + SetupTensor(subgraph, kStringOutput2, kTfLiteString); + SetupTensor(subgraph, kIntegerOutput, kTfLiteInt32); + + TfLiteIfParams* params = + reinterpret_cast(malloc(sizeof(TfLiteWhileParams))); + params->then_subgraph_index = 1; + params->else_subgraph_index = 2; + auto* if_reg = ops::builtin::Register_IF(); + if_reg->builtin_code = kTfLiteBuiltinIf; + + int node_index; + subgraph->AddNodeWithParameters(input_tensors, output_tensors, {}, nullptr, 0, + params, if_reg, &node_index); +} + void SubgraphBuilder::BuildWhileSubgraphWithDynamicTensor(Subgraph* subgraph) { const int kStringInput1 = 0; const int kStringInput2 = 1; diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h index 6bddd866e43f69..abda4f30bc28d0 100644 --- a/tensorflow/lite/kernels/subgraph_test_util.h +++ b/tensorflow/lite/kernels/subgraph_test_util.h @@ -66,6 +66,9 @@ class SubgraphBuilder { // Build a subgraph whose output is not consumed by the parent subgraph. void BuildOutputNotConsumedSubgraph(Subgraph& subgraph); + // Build an if subgraph with float inputs and outputs. + void BuildFloatIfSubgraph(Subgraph* subgraph, int num_inputs); + // Build a while subgraph with float inputs and outputs. void BuildFloatWhileSubgraph(Subgraph* subgraph, int num_inputs); @@ -123,13 +126,25 @@ class SubgraphBuilder { // Equivalent to (input < rhs). void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs); + // Build an if subgraph which does not consume an output of ifs body + // subgraph. + void BuildOutputNotConsumedIfSubgraph(Subgraph* subgraph); + // Build a while subgraph which does not consume an output of ifs body // subgraph. void BuildOutputNotConsumedWhileSubgraph(Subgraph* subgraph); + // Build a if subgraph with multiple inputs. + void BuildMultiInputIfSubgraph(Subgraph* subgraph, int num_inputs); + // Build a while subgraph with multiple inputs. void BuildMultiInputWhileSubgraph(Subgraph* subgraph, int num_inputs); + // Build an if subgraph with multiple inputs and one output which is not + // consumed. + void BuildMultiInputIfSubgraphWithUnconsumedOutput(Subgraph* subgraph, + int num_inputs); + // Build a while subgraph with multiple inputs and one output which is not // consumed. void BuildMultiInputWhileSubgraphWithUnconsumedOutput(Subgraph* subgraph, @@ -201,6 +216,10 @@ class SubgraphBuilder { // (str1, Fill(str1, int_val + 1), int_val + 1). void BuildBodySubgraphWithDynamicTensor(Subgraph* subgraph); + // Build a subgraph with a single If op, that contains 4 inputs and 3 + // outputs (str1, str2, int_val). + void BuildIfSubgraphWithDynamicTensor(Subgraph* subgraph); + // Build a subgraph with a single While op, that contains 3 inputs and 3 // outputs (str1, str2, int_val). 
void BuildWhileSubgraphWithDynamicTensor(Subgraph* subgraph); diff --git a/tensorflow/lite/kernels/while.cc b/tensorflow/lite/kernels/while.cc index fd2251f10cfc6f..8608ccbf764def 100644 --- a/tensorflow/lite/kernels/while.cc +++ b/tensorflow/lite/kernels/while.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/kernels/control_flow_common.h" #include "tensorflow/lite/kernels/kernel_util.h" namespace tflite { @@ -40,139 +41,6 @@ struct OpData { namespace { -// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` -// to `dst_tensor_indices` in `dst_subgraph`. -// -// When `resize_subgraph_inputs` is true, the function calls subgraphs's -// `ResizeInputTensor` function, and it may trigger the memory planner to -// reallocate memory. -// When `resize_subgraph_inputs` is false, it implies `context` belongs to -// `dst_subgraph`. The function calls `context->ResizeTensor`. This happens -// when resizing `While` op's outputs. -template -TfLiteStatus CopyTensorsShapeAndType(TfLiteContext* context, - Subgraph* src_subgraph, - const SrcVector& src_tensor_indices, - Subgraph* dst_subgraph, - const DstVector& dst_tensor_indices, - bool resize_subgraph_inputs) { - TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), - dst_tensor_indices.size()); - for (int i = 0; i < src_tensor_indices.size(); ++i) { - // Skip copying unused destination tensors. - if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; - - const TfLiteTensor* src_tensor = - src_subgraph->tensor(src_tensor_indices[i]); - - TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); - if (resize_subgraph_inputs) { - std::vector dims(src_tensor->dims->data, - src_tensor->dims->data + src_tensor->dims->size); - dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); - } else { - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, dst_tensor, - TfLiteIntArrayCopy(src_tensor->dims))); - } - dst_tensor->type = src_tensor->type; - } - return kTfLiteOk; -} - -// Copy the tensors data from tensors `src_tensor_indices` in `src_subgraph` -// to `dst_tensor_indices` in `dst_subgraph`. -template -TfLiteStatus CopyTensorsData(TfLiteContext* context, Subgraph* src_subgraph, - const SrcVector& src_tensor_indices, - Subgraph* dst_subgraph, - const DstVector& dst_tensor_indices) { - TF_LITE_ENSURE_EQ(context, src_tensor_indices.size(), - dst_tensor_indices.size()); - for (int i = 0; i < src_tensor_indices.size(); ++i) { - // Skip copying unused destination tensors. - if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; - - const TfLiteTensor* src_tensor = - src_subgraph->tensor(src_tensor_indices[i]); - TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); - if (IsDynamicTensor(dst_tensor)) { - TfLiteTensorRealloc(src_tensor->bytes, dst_tensor); - } - TF_LITE_ENSURE_OK(context, TfLiteTensorCopy(src_tensor, dst_tensor)); - } - return kTfLiteOk; -} - -// Propagate tensor shapes and types from `src_tensor_indices` in `src_subgraph` -// to `dst_tensor_indices` in `dst_subgraph` and copy data deeply. 
-template -TfLiteStatus DeepCopyTensorsShapeTypeData(TfLiteContext* context, - TfLiteNode* node, - Subgraph* src_subgraph, - const SrcVector& src_tensor_indices, - Subgraph* dst_subgraph, - const DstVector& dst_tensor_indices) { - const OpData* op_data = reinterpret_cast(node->user_data); - - if (op_data->body_has_dynamic_output_tensors) { - Subgraph* this_subgraph = reinterpret_cast(context->impl_); - bool resize_subgraph_inputs = (dst_subgraph != this_subgraph); - TF_LITE_ENSURE_OK( - context, CopyTensorsShapeAndType( - context, src_subgraph, src_tensor_indices, dst_subgraph, - dst_tensor_indices, resize_subgraph_inputs)); - if (resize_subgraph_inputs) { - TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); - } - } - TF_LITE_ENSURE_OK(context, - CopyTensorsData(context, src_subgraph, src_tensor_indices, - dst_subgraph, dst_tensor_indices)); - return kTfLiteOk; -} - -template -TfLiteStatus DeepOrShallowCopyTensorsShapeTypeData( - TfLiteContext* context, TfLiteNode* node, Subgraph* src_subgraph, - const SrcVector& src_tensor_indices, Subgraph* dst_subgraph, - const DstVector& dst_tensor_indices) { - for (int i = 0; i < src_tensor_indices.size(); ++i) { - // Skip copying unused destination tensors. - if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; - if (src_tensor_indices[i] == kTfLiteOptionalTensor) continue; - - const TfLiteTensor* src_tensor = - src_subgraph->tensor(src_tensor_indices[i]); - TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); - std::vector dims(src_tensor->dims->data, - src_tensor->dims->data + src_tensor->dims->size); - dst_subgraph->ResizeInputTensor(dst_tensor_indices[i], dims); - dst_tensor->type = src_tensor->type; - if (!IsResourceOrVariant(src_tensor)) { - dst_tensor->bytes = 0; // Don't allocate memory with AllocateTensors(). - dst_tensor->data.raw = nullptr; - } - } - TF_LITE_ENSURE_OK(context, dst_subgraph->AllocateTensors()); - for (int i = 0; i < src_tensor_indices.size(); ++i) { - // Skip copying unused destination tensors. - if (dst_tensor_indices[i] == kTfLiteOptionalTensor) continue; - if (src_tensor_indices[i] == kTfLiteOptionalTensor) continue; - - const TfLiteTensor* src_tensor = - src_subgraph->tensor(src_tensor_indices[i]); - TfLiteTensor* dst_tensor = dst_subgraph->tensor(dst_tensor_indices[i]); - if (IsResourceOrVariant(src_tensor)) { - TfLiteTensorRealloc(src_tensor->bytes, dst_tensor); - TF_LITE_ENSURE_OK(context, TfLiteTensorCopy(src_tensor, dst_tensor)); - } else { - dst_tensor->bytes = src_tensor->bytes; - dst_tensor->data.raw = src_tensor->data.raw; - } - } - return kTfLiteOk; -} TfLiteStatus CheckCondOutput(TfLiteContext* context, const TfLiteTensor* cond_output) { @@ -436,14 +304,16 @@ TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, DeepCopyTensorsShapeTypeData( context, node, this_subgraph, TfLiteIntArrayView(node->inputs), - cond_subgraph, cond_subgraph->inputs())); + cond_subgraph, cond_subgraph->inputs(), + op_data->body_has_dynamic_output_tensors)); // Step 2. 
node->inputs -> node->outputs TF_LITE_ENSURE_OK( - context, DeepCopyTensorsShapeTypeData(context, node, this_subgraph, - TfLiteIntArrayView(node->inputs), - this_subgraph, - TfLiteIntArrayView(node->outputs))); + context, + DeepCopyTensorsShapeTypeData( + context, node, this_subgraph, TfLiteIntArrayView(node->inputs), + this_subgraph, TfLiteIntArrayView(node->outputs), + op_data->body_has_dynamic_output_tensors)); SetupUnconsumedOutputs(node, op_data, this_subgraph, body_subgraph); @@ -474,13 +344,15 @@ TfLiteStatus Eval_dynamic(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, DeepCopyTensorsShapeTypeData( context, node, body_subgraph, body_subgraph->outputs(), - cond_subgraph, cond_subgraph->inputs())); + cond_subgraph, cond_subgraph->inputs(), + op_data->body_has_dynamic_output_tensors)); // Step 7. body->outputs -> node->outputs TF_LITE_ENSURE_OK( context, DeepCopyTensorsShapeTypeData( context, node, body_subgraph, body_subgraph->outputs(), - this_subgraph, TfLiteIntArrayView(node->outputs))); + this_subgraph, TfLiteIntArrayView(node->outputs), + op_data->body_has_dynamic_output_tensors)); } return kTfLiteOk; diff --git a/tensorflow/lite/kernels/while_test.cc b/tensorflow/lite/kernels/while_test.cc index 8cc2df70233e08..0e0a3e43a72727 100644 --- a/tensorflow/lite/kernels/while_test.cc +++ b/tensorflow/lite/kernels/while_test.cc @@ -137,7 +137,6 @@ TEST_F(WhileTest, TestFlexOutput) { builder_->BuildFlexOutputSubgraph(interpreter_->subgraph(2)); builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 2); - // ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2}), @@ -164,7 +163,6 @@ TEST_F(WhileTest, TestCounterOnly) { builder_->BuildCounterOnlySubgraph(interpreter_->subgraph(2)); builder_->BuildMultiInputWhileSubgraph(&interpreter_->primary_subgraph(), 1); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); ASSERT_EQ(interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}), kTfLiteOk); ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); @@ -684,7 +682,7 @@ TEST_F(WhileTest, TestWhileLoopWithDynamicTensor) { ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); TfLiteTensor* string_output1 = - interpreter_->tensor(interpreter_->inputs()[0]); + interpreter_->tensor(interpreter_->outputs()[0]); CheckScalarStringTensor(string_output1, "A"); TfLiteTensor* string_output2 = interpreter_->tensor(interpreter_->outputs()[1]); diff --git a/tensorflow/lite/profiling/profile_summarizer_test.cc b/tensorflow/lite/profiling/profile_summarizer_test.cc index 8c180e605c9243..c2ed4d47755a51 100644 --- a/tensorflow/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/lite/profiling/profile_summarizer_test.cc @@ -183,7 +183,7 @@ TEST_F(ProfileSummarizerIfOpTest, TestIfTrue) { subgraph_test_util::CheckIntTensor(output, {1, 2}, {6, 9}); auto events = profiler.GetProfileEvents(); - EXPECT_EQ(4, events.size()); + EXPECT_EQ(5, events.size()); int event_count_of_subgraph_zero = std::count_if( events.begin(), events.end(), [](auto event) { return event->extra_event_metadata == 0; }); @@ -194,7 +194,7 @@ TEST_F(ProfileSummarizerIfOpTest, TestIfTrue) { events.begin(), events.end(), [](auto event) { return event->extra_event_metadata == 2; }); EXPECT_EQ(2, event_count_of_subgraph_zero); - EXPECT_EQ(2, event_count_of_subgraph_one); + EXPECT_EQ(3, event_count_of_subgraph_one); EXPECT_EQ(0, 
event_count_of_subgraph_two); } @@ -210,7 +210,7 @@ TEST_F(ProfileSummarizerIfOpTest, TestIfFalse) { subgraph_test_util::CheckIntTensor(output, {1, 2}, {5, 14}); auto events = profiler.GetProfileEvents(); - EXPECT_EQ(4, events.size()); + EXPECT_EQ(5, events.size()); int event_count_of_subgraph_zero = std::count_if( events.begin(), events.end(), [](auto event) { return event->extra_event_metadata == 0; }); @@ -222,7 +222,7 @@ TEST_F(ProfileSummarizerIfOpTest, TestIfFalse) { [](auto event) { return event->extra_event_metadata == 2; }); EXPECT_EQ(2, event_count_of_subgraph_zero); EXPECT_EQ(0, event_count_of_subgraph_one); - EXPECT_EQ(2, event_count_of_subgraph_two); + EXPECT_EQ(3, event_count_of_subgraph_two); } } // namespace From d781442d3e2ad5cf98040a8bc77f0770a8e13945 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 28 Jul 2023 08:11:58 -0700 Subject: [PATCH 311/410] Sync 61415 --- .../compiler/xla/service/gpu/tests/gemm_rewrite_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index fa84776e5430b1..be0fe80389e813 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -422,7 +422,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,5,3], y: f32[5,3,4]) -> f32[5,2,4] { @@ -461,7 +461,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,5], y: f32[5,3,4]) -> f32[5,2,4] { @@ -582,7 +582,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,5,4] { @@ -622,7 +622,7 @@ ENTRY AddDotsFunc { )"; - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{2.5e-5, 1e-5})); MatchOptimizedHlo(hlo_text, R"( ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[5,2,3], y: f32[5,3,4]) -> f32[2,4,5] { From 9bc986c1c715881c7fa8fa37a72900ba0d41b749 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 28 Jul 2023 08:12:40 -0700 Subject: [PATCH 312/410] Change log to only report every 10 seconds to reduce frequency of this log. 
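For context, a minimal standalone sketch of the throttled-logging pattern this change adopts in shuffle_dataset_op.cc below. Only the `LOG_EVERY_N_SEC` macro from absl/log and the 10-second window come from the change; the `FillBuffer` function and its loop are hypothetical stand-ins for the shuffle-buffer fill path.

    // Sketch only: LOG_EVERY_N_SEC(severity, n) emits at most one message per
    // n-second window per call site, so a hot fill loop no longer floods logs.
    #include <cstdint>
    #include "absl/log/log.h"

    void FillBuffer(int64_t total_elements) {
      for (int64_t i = 0; i < total_elements; ++i) {
        // ... add element i to the buffer (hypothetical work) ...
        LOG_EVERY_N_SEC(INFO, 10)
            << "Filling up shuffle buffer (this may take a while): " << i
            << " of " << total_elements;
      }
    }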
PiperOrigin-RevId: 551856037 --- tensorflow/core/kernels/data/BUILD | 1 + tensorflow/core/kernels/data/shuffle_dataset_op.cc | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 143a57262855ea..a28faea865dedf 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1169,6 +1169,7 @@ tf_kernel_library( "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:serialization_utils", + "@com_google_absl//absl/log", "@com_google_absl//absl/random", ], ) diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 50043f6feeb471..d41f3226175fa7 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/serialization_utils.h" @@ -419,8 +420,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { if (EnvTime::NowMicros() > ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) { num_log_entries++; - LOG(INFO) << "Filling up shuffle buffer (this may take a while): " - << num_elements_ << " of " << BufferSizeString(); + LOG_EVERY_N_SEC(INFO, 10) + << dataset()->metadata().name() << ": " + << "Filling up shuffle buffer (this may take a while): " + << num_elements_ << " of " << BufferSizeString(); } if (!input_impl_) { TF_RETURN_IF_ERROR(PrepareNextEpoch(ctx)); From 7636f9c8fdddf99ee6fb1a8b6f064605cc988cdb Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Fri, 28 Jul 2023 09:49:22 -0700 Subject: [PATCH 313/410] Avoid collision of channel_ids for send/recv and collectives. Use a higher starting value for collectives. This decreases the chance of collision but doesn't completely eliminate it. 
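The pass change below only moves the starting value of the collective channel-ID counter from 1 to 10000 so it draws from a range disjoint from the IDs handed out for send/recv lowering; as the commit message notes, this lowers the chance of a clash without ruling it out. A toy sketch of the idea (hypothetical class and names, not the actual pass code):

  #include <cstdint>
  #include <iostream>

  // Two independent counters drawing from widely separated starting points.
  // Nothing enforces that the ranges never meet, which is why this is a
  // mitigation rather than a guarantee.
  class ChannelIdAllocator {
   public:
    explicit ChannelIdAllocator(int64_t start) : next_(start) {}
    int64_t Next() { return next_++; }

   private:
    int64_t next_;
  };

  int main() {
    ChannelIdAllocator send_recv_ids(/*start=*/1);       // send/recv lowering
    ChannelIdAllocator collective_ids(/*start=*/10000);  // collective lowering
    std::cout << send_recv_ids.Next() << " " << send_recv_ids.Next() << "\n";  // 1 2
    std::cout << collective_ids.Next() << "\n";                                // 10000
    return 0;
  }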
PiperOrigin-RevId: 551878287 --- .../mlir/tf2xla/tests/legalize-tf-collective.mlir | 8 ++++---- .../mlir/tf2xla/transforms/legalize_tf_collective.cc | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir index db7dbac0272347..fa0bc94a980eb0 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf-collective.mlir @@ -26,13 +26,13 @@ func.func @all_reduce_cross_replica_and_partition(%input: tensor) -> tensor // CHECK: "mhlo.all_reduce" // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle + // CHECK-NEXT: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> %0 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor, tensor<2x1xi32>) -> tensor // CHECK: "mhlo.all_reduce" // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle + // CHECK-NEXT: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> %1 = "tf.XlaAllReduce"(%input, %group_assignment) {reduce_op = "Add", mode = "CrossReplicaAndPartition"} : (tensor, tensor<2x1xi32>) -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor @@ -112,13 +112,13 @@ func.func @collective_reduce_v2(%input: tensor) -> tensor { // CHECK: "mhlo.all_reduce" // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle + // CHECK-NEXT: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor // CHECK: "mhlo.all_reduce" // CHECK: mhlo.add // CHECK: mhlo.return - // CHECK-NEXT: channel_handle = #mhlo.channel_handle + // CHECK-NEXT: channel_handle = #mhlo.channel_handle // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 4f355d5255c5cd..8d098c555765b1 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -359,8 +359,8 @@ class ConvertCollectiveReduceV2 void LegalizeTFCollective::runOnOperation() { // FIXME(b/226139061): Figure out a way to share the channel_id with - // send/recv Ops. - int64_t channel_id = 1; + // send/recv Ops. For now, start with a different range to avoid collision. + int64_t channel_id = 10000; auto module = getOperation(); MLIRContext* context = module->getContext(); From d43e620678747053acf2983818a2b16993e44cfe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 28 Jul 2023 10:19:10 -0700 Subject: [PATCH 314/410] Some CustomCalls may wish their called_computation to be sharded, too. This plumbs in the same mechanism that is used for kWhile and kCondition for CustomCall. 
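The header change below gives CustomCallShardingHelper a GetRelatedInstructions hook, and sharding propagation now consults it for kCustomCall the same way it builds the related-instruction set for kWhile and kConditional. A hypothetical subclass might override it roughly as follows; the class name, the custom-call target, and the choice to return the called computations' roots are assumptions for illustration, not part of the patch:

  #include <vector>

  #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h"
  #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
  #include "tensorflow/compiler/xla/service/custom_call_sharding_helper.h"

  namespace xla {

  // Sketch only: keeps the roots of a custom call's called computations
  // sharded like the call itself, analogous to the while/conditional handling.
  class ExampleCustomCallShardingHelper : public CustomCallShardingHelper {
   public:
    bool IsCustomCallShardable(
        const HloInstruction* instruction) const override {
      return instruction->custom_call_target() == "example_sharded_call";
    }

    std::vector<HloInstruction*> GetRelatedInstructions(
        HloInstruction* instruction) const override {
      std::vector<HloInstruction*> related = {instruction};
      for (HloComputation* computation : instruction->called_computations()) {
        related.push_back(computation->root_instruction());
      }
      return related;
    }
  };

  }  // namespace xla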
PiperOrigin-RevId: 551886137 --- .../xla/service/custom_call_sharding_helper.h | 7 +++++++ .../compiler/xla/service/sharding_propagation.cc | 11 +++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/custom_call_sharding_helper.h b/tensorflow/compiler/xla/service/custom_call_sharding_helper.h index 72a350340aa8d2..7999c797e59d2d 100644 --- a/tensorflow/compiler/xla/service/custom_call_sharding_helper.h +++ b/tensorflow/compiler/xla/service/custom_call_sharding_helper.h @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" @@ -40,6 +41,12 @@ class CustomCallShardingHelper { // Returns if the instruction passed as parameter is a supported custom-call // for which the functions of this class are implemented. virtual bool IsCustomCallShardable(const HloInstruction* instruction) const; + // Returns the list of instructions in sub-computations that must be sharded + // in the same way as `instruction`. + virtual std::vector GetRelatedInstructions( + HloInstruction* instruction) const { + return {}; + } virtual ~CustomCallShardingHelper() = default; }; diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 51b0864274ae3a..615c321befbadc 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -2557,7 +2557,7 @@ StatusOr ShardingPropagation::Run( // Instructions that are related through a computation and need to share the // same sharding. - auto get_related_instructions = [](HloInstruction* inst) { + auto get_related_instructions = [this](HloInstruction* inst) { if (inst->opcode() == HloOpcode::kWhile) { return std::vector{ inst, inst->while_body()->root_instruction(), @@ -2572,6 +2572,12 @@ StatusOr ShardingPropagation::Run( comps.push_back(c->root_instruction()); } return comps; + } else if (inst->opcode() == HloOpcode::kCustomCall) { + if (sharding_helper_ && sharding_helper_->IsCustomCallShardable(inst)) { + return sharding_helper_->GetRelatedInstructions(inst); + } else { + return std::vector{}; + } } else { CHECK(false); } @@ -2601,7 +2607,8 @@ StatusOr ShardingPropagation::Run( }; if (instruction->opcode() == HloOpcode::kConditional || - instruction->opcode() == HloOpcode::kWhile) { + instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kCustomCall) { propagate_to_instruction(instruction); } From 059fab84aae5d456f91113a9bba6d76acc33d76b Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Fri, 28 Jul 2023 10:25:57 -0700 Subject: [PATCH 315/410] Allow the use of different sized program keys for _XlaSendFromHost and _XlaRecvAtHost. TPUCompile returns a program_key with 3 strings in a vector. XLACompile returns a program key with a single scalar string. 
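The kernel change below lets _XlaRecvAtHost and _XlaSendFromHost accept either key layout: TPUCompile's 1-D string tensor of length 3 (program key at index 1) or XLACompile's scalar string. The key extraction reduces to roughly the following standalone helper (the function name is mine; the OP_REQUIRES shape validation done in the kernel is omitted):

  #include "tensorflow/core/framework/tensor.h"
  #include "tensorflow/core/framework/tensor_shape.h"
  #include "tensorflow/core/platform/tstring.h"

  namespace {

  // Returns the rendezvous key base for either supported layout of the
  // compile op's program-key output.
  tensorflow::tstring GetRendezvousKeyBase(const tensorflow::Tensor& key_input) {
    if (tensorflow::TensorShapeUtils::IsVector(key_input.shape())) {
      // TPUCompile: vector of 3 strings, program key at index 1.
      return key_input.vec<tensorflow::tstring>()(1);
    }
    // XLACompile: a single scalar string holding the program key.
    return key_input.flat<tensorflow::tstring>()(0);
  }

  }  // namespace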
PiperOrigin-RevId: 551887983 --- .../core/tpu/kernels/host_compute_ops.cc | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/tpu/kernels/host_compute_ops.cc b/tensorflow/core/tpu/kernels/host_compute_ops.cc index 005453a8be03c7..be02eb0d9109b4 100644 --- a/tensorflow/core/tpu/kernels/host_compute_ops.cc +++ b/tensorflow/core/tpu/kernels/host_compute_ops.cc @@ -102,16 +102,25 @@ class RecvAtHostOp : public AsyncOpKernel { VLOG(2) << " cpu_device_ = " << cpu_device_; } + // Depending on the compile op, the input could be either + // a scalar program key and a 1D vector of length 3 with program key at + // index 1. const Tensor& input = ctx->input(0); VLOG(2) << input.DebugString(); OP_REQUIRES_ASYNC( ctx, TensorShapeUtils::IsVector(input.shape()) && - input.shape().dim_size(0) == 3, + input.shape().dim_size(0) == 3 || + TensorShapeUtils::IsScalar(input.shape()), errors::InvalidArgument("Input shape ", input.shape().DebugString(), " is not a vector of length 3."), done); - const string rendezvous_key_base = input.vec()(1); + string rendezvous_key_base; + if (TensorShapeUtils::IsVector(input.shape())) { + rendezvous_key_base = input.vec()(1); + } else { + rendezvous_key_base = input.flat()(0); + } OP_REQUIRES_ASYNC( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous."), @@ -262,13 +271,23 @@ class SendFromHostOp : public OpKernel { const int num_send_inputs = ctx->num_inputs() - (device_ordinal_is_attr ? 1 : 2); const Tensor& key_input = ctx->input(num_send_inputs); + // Depending on the compile op, the key_input could be either + // a scalar program key and a 1D vector of length 3 with program key at + // index 1. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(key_input.shape()) && - key_input.shape().dim_size(0) == 3, + key_input.shape().dim_size(0) == 3 || + TensorShapeUtils::IsScalar(key_input.shape()), errors::InvalidArgument("Key input shape ", key_input.shape().DebugString(), " is not a vector of length 3.")); - const string rendezvous_key_base = key_input.vec()(1); + string rendezvous_key_base; + if (TensorShapeUtils::IsVector(key_input.shape())) { + rendezvous_key_base = key_input.vec()(1); + } else { + rendezvous_key_base = key_input.flat()(0); + } + OP_REQUIRES( ctx, ctx->rendezvous() != nullptr, errors::Internal("Op kernel context needs to provide a rendezvous.")); From 1631dadbbfce0f1243b0035f07499057f3b0c805 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 28 Jul 2023 10:35:38 -0700 Subject: [PATCH 316/410] Integrate LLVM at llvm/llvm-project@cb924ddca514 Updates LLVM usage to match [cb924ddca514](https://github.com/llvm/llvm-project/commit/cb924ddca514) PiperOrigin-RevId: 551890570 --- .../tensorflow/analysis/resource_dataflow.h | 5 +- third_party/llvm/generated.patch | 1197 +++++++++++++++++ third_party/llvm/workspace.bzl | 4 +- third_party/triton/cl551490193.patch | 55 + third_party/triton/workspace.bzl | 1 + 5 files changed, 1258 insertions(+), 4 deletions(-) create mode 100644 third_party/triton/cl551490193.patch diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h index 4ddc3610adfe05..68f6fa2d44c763 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -55,11 +55,12 @@ struct ResourceConstructingOps { }; class ResourceDataflowAnalysis - : public dataflow::SparseDataFlowAnalysis< + : public dataflow::SparseForwardDataFlowAnalysis< dataflow::Lattice> { public: using StateT = dataflow::Lattice; - using dataflow::SparseDataFlowAnalysis::SparseDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis< + StateT>::SparseForwardDataFlowAnalysis; ~ResourceDataflowAnalysis() override = default; void visitOperation(Operation *op, ArrayRef operands, diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..0fc235f89ab7e5 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,1198 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp +--- a/clang/lib/Sema/SemaCast.cpp ++++ b/clang/lib/Sema/SemaCast.cpp +@@ -935,14 +935,6 @@ + << isClangCL; + } + +- // For a dynamic_cast to a final type, IR generation might emit a reference +- // to the vtable. +- if (DestRecord) { +- auto *DestDecl = DestRecord->getAsCXXRecordDecl(); +- if (DestDecl->isEffectivelyFinal()) +- Self.MarkVTableUsed(OpRange.getBegin(), DestDecl); +- } +- + // Done. Everything else is run-time checks. + Kind = CK_Dynamic; + } +diff -ruN --strip-trailing-cr a/clang/test/CodeGenCXX/dynamic-cast-exact.cpp b/clang/test/CodeGenCXX/dynamic-cast-exact.cpp +--- a/clang/test/CodeGenCXX/dynamic-cast-exact.cpp ++++ b/clang/test/CodeGenCXX/dynamic-cast-exact.cpp +@@ -76,12 +76,3 @@ + // CHECK: phi ptr [ %[[RESULT]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_FAILED]] ] + return dynamic_cast(a); + } +- +-namespace GH64088 { +- // Ensure we mark the B vtable as used here, because we're going to emit a +- // reference to it. +- // CHECK: define {{.*}} @_ZN7GH640881BD0 +- struct A { virtual ~A(); }; +- struct B final : A { virtual ~B() = default; }; +- B *cast(A *p) { return dynamic_cast(p); } +-} +diff -ruN --strip-trailing-cr a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h +--- a/llvm/include/llvm/ProfileData/SampleProf.h ++++ b/llvm/include/llvm/ProfileData/SampleProf.h +@@ -318,14 +318,6 @@ + + raw_ostream &operator<<(raw_ostream &OS, const LineLocation &Loc); + +-static inline hash_code hashFuncName(StringRef F) { +- // If function name is already MD5 string, do not hash again. +- uint64_t Hash; +- if (F.getAsInteger(10, Hash)) +- Hash = MD5Hash(F); +- return Hash; +-} +- + /// Representation of a single sample record. 
+ /// + /// A sample record is represented by a positive integer value, which +@@ -638,13 +630,9 @@ + return getContextString(FullContext, false); + } + +- hash_code getHashCode() const { +- if (hasContext()) +- return hash_value(getContextFrames()); +- +- // For non-context function name, use its MD5 as hash value, so that it is +- // consistent with the profile map's key. +- return hashFuncName(getName()); ++ uint64_t getHashCode() const { ++ return hasContext() ? hash_value(getContextFrames()) ++ : hash_value(getName()); + } + + /// Set the name of the function and clear the current context. +@@ -722,12 +710,9 @@ + uint32_t Attributes; + }; + +-static inline hash_code hash_value(const SampleContext &Context) { +- return Context.getHashCode(); +-} +- +-inline raw_ostream &operator<<(raw_ostream &OS, const SampleContext &Context) { +- return OS << Context.toString(); ++static inline hash_code hash_value(const SampleContext &arg) { ++ return arg.hasContext() ? hash_value(arg.getContextFrames()) ++ : hash_value(arg.getName()); + } + + class FunctionSamples; +@@ -1221,9 +1206,6 @@ + return !(*this == Other); + } + +- template +- const T &getKey() const; +- + private: + /// CFG hash value for the function. + uint64_t FunctionHash = 0; +@@ -1287,176 +1269,12 @@ + const LocToLocMap *IRToProfileLocationMap = nullptr; + }; + +-template <> +-inline const SampleContext &FunctionSamples::getKey() const { +- return getContext(); +-} +- + raw_ostream &operator<<(raw_ostream &OS, const FunctionSamples &FS); + +-/// This class is a wrapper to associative container MapT using +-/// the hash value of the original key as the new key. This greatly improves the +-/// performance of insert and query operations especially when hash values of +-/// keys are available a priori, and reduces memory usage if KeyT has a large +-/// size. +-/// When performing any action, if an existing entry with a given key is found, +-/// and the interface "KeyT ValueT::getKey() const" to retrieve a value's +-/// original key exists, this class checks if the given key actually matches +-/// the existing entry's original key. If they do not match, this class behaves +-/// as if the entry did not exist (for insertion, this means the new value will +-/// replace the existing entry's value, as if it is newly inserted). If +-/// ValueT::getKey() is not available, all keys with the same hash value +-/// are considered equivalent (i.e. hash collision is silently ignored). Given +-/// such feature this class should only be used where it does not affect +-/// compilation correctness, for example, when loading a sample profile. +-/// Assuming the hashing algorithm is uniform, the probability of hash collision +-/// with 1,000,000 entries is +-/// (2^64)!/((2^64-1000000)!*(2^64)^1000000) ~= 3*10^-8. +-template